mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add a 'Count tokens' button
This commit is contained in:
parent
a6ef2429fa
commit
c238ba9532
28
server.py
28
server.py
@ -1,7 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import requests
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
||||||
@ -15,10 +16,12 @@ def my_get(url, **kwargs):
|
|||||||
original_get = requests.get
|
original_get = requests.get
|
||||||
requests.get = my_get
|
requests.get = my_get
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
requests.get = original_get
|
requests.get = original_get
|
||||||
|
|
||||||
# This fixes LaTeX rendering on some systems
|
# This fixes LaTeX rendering on some systems
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg')
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
@ -44,7 +47,8 @@ from modules import api, chat, shared, training, ui
|
|||||||
from modules.html_generator import chat_html_wrapper
|
from modules.html_generator import chat_html_wrapper
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt, unload_model
|
from modules.models import load_model, load_soft_prompt, unload_model
|
||||||
from modules.text_generation import generate_reply, stop_everything_event
|
from modules.text_generation import (encode, generate_reply,
|
||||||
|
stop_everything_event)
|
||||||
|
|
||||||
|
|
||||||
def get_available_models():
|
def get_available_models():
|
||||||
@ -172,6 +176,11 @@ def load_prompt(fname):
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(text):
|
||||||
|
tokens = len(encode(text)[0])
|
||||||
|
return f'{tokens} tokens in the input.'
|
||||||
|
|
||||||
|
|
||||||
def download_model_wrapper(repo_id):
|
def download_model_wrapper(repo_id):
|
||||||
try:
|
try:
|
||||||
downloader = importlib.import_module("download-model")
|
downloader = importlib.import_module("download-model")
|
||||||
@ -628,6 +637,7 @@ def create_interface():
|
|||||||
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
|
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
|
||||||
|
|
||||||
shared.gradio['save_prompt'] = gr.Button('Save prompt')
|
shared.gradio['save_prompt'] = gr.Button('Save prompt')
|
||||||
|
shared.gradio['count_tokens'] = gr.Button('Count tokens')
|
||||||
shared.gradio['status'] = gr.Markdown('')
|
shared.gradio['status'] = gr.Markdown('')
|
||||||
|
|
||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
@ -644,10 +654,11 @@ def create_interface():
|
|||||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox_default", lines=27, label='Input')
|
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox_default", lines=27, label='Input')
|
||||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Generate'] = gr.Button('Generate', variant='primary')
|
shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
|
||||||
shared.gradio['Stop'] = gr.Button('Stop')
|
shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
|
||||||
shared.gradio['Continue'] = gr.Button('Continue')
|
shared.gradio['Continue'] = gr.Button('Continue', elem_classes="small-button")
|
||||||
shared.gradio['save_prompt'] = gr.Button('Save prompt')
|
shared.gradio['save_prompt'] = gr.Button('Save prompt', elem_classes="small-button")
|
||||||
|
shared.gradio['count_tokens'] = gr.Button('Count tokens', elem_classes="small-button")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -838,8 +849,9 @@ def create_interface():
|
|||||||
)
|
)
|
||||||
|
|
||||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False)
|
||||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
shared.gradio['save_prompt'].click(save_prompt, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
|
||||||
|
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
# Launch the interface
|
# Launch the interface
|
||||||
|
Loading…
Reference in New Issue
Block a user