From ef04138bc0379cac4c0d19b1a5190be95f0be5d0 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 15 Sep 2023 19:30:44 -0700 Subject: [PATCH] Improve the UI tokenizer --- css/main.css | 4 ++++ modules/text_generation.py | 11 +++++++++++ modules/ui_default.py | 6 +++--- modules/ui_notebook.py | 6 +++--- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/css/main.css b/css/main.css index f7641f11..69262c90 100644 --- a/css/main.css +++ b/css/main.css @@ -127,6 +127,10 @@ div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { height: calc(100dvh - 292px); } +.monospace { + font-family: monospace; +} + .textbox_default textarea, .textbox_default_output textarea, .textbox_logits textarea, diff --git a/modules/text_generation.py b/modules/text_generation.py index 98682bb2..a3755c10 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -144,6 +144,17 @@ def get_encoded_length(prompt): return len(encode(prompt)[0]) +def get_token_ids(prompt): + tokens = encode(prompt)[0] + decoded_tokens = [shared.tokenizer.decode(i) for i in tokens] + + output = '' + for row in list(zip(tokens, decoded_tokens)): + output += f"{str(int(row[0])).ljust(5)} - {row[1]}\n" + + return output + + def get_max_prompt_length(state): return state['truncation_length'] - state['max_new_tokens'] diff --git a/modules/ui_default.py b/modules/ui_default.py index 38056701..7357094b 100644 --- a/modules/ui_default.py +++ b/modules/ui_default.py @@ -3,8 +3,8 @@ import gradio as gr from modules import logits, shared, ui, utils from modules.prompts import count_tokens, load_prompt from modules.text_generation import ( - encode, generate_reply_wrapper, + get_token_ids, stop_everything_event ) from modules.utils import gradio @@ -57,7 +57,7 @@ def create_ui(): with gr.Tab('Tokens'): shared.gradio['get_tokens-default'] = gr.Button('Get token IDs for the input') - shared.gradio['tokens-default'] = gr.Textbox(lines=23, label='Tokens', elem_classes=['textbox_logits', 'add_scrollbar']) + shared.gradio['tokens-default'] = gr.Textbox(lines=23, label='Tokens', elem_classes=['textbox_logits', 'add_scrollbar', 'monospace']) def create_event_handlers(): @@ -100,4 +100,4 @@ def create_event_handlers(): ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( logits.get_next_logits, gradio('textbox-default', 'interface_state', 'use_samplers-default', 'logits-default'), gradio('logits-default', 'logits-default-previous'), show_progress=False) - shared.gradio['get_tokens-default'].click(lambda x : str(encode(x)[0].tolist()), gradio('textbox-default'), gradio('tokens-default'), show_progress=False) + shared.gradio['get_tokens-default'].click(get_token_ids, gradio('textbox-default'), gradio('tokens-default'), show_progress=False) diff --git a/modules/ui_notebook.py b/modules/ui_notebook.py index 269de3ec..60e3ee4e 100644 --- a/modules/ui_notebook.py +++ b/modules/ui_notebook.py @@ -3,8 +3,8 @@ import gradio as gr from modules import logits, shared, ui, utils from modules.prompts import count_tokens, load_prompt from modules.text_generation import ( - encode, generate_reply_wrapper, + get_token_ids, stop_everything_event ) from modules.utils import gradio @@ -43,7 +43,7 @@ def create_ui(): with gr.Tab('Tokens'): shared.gradio['get_tokens-notebook'] = gr.Button('Get token IDs for the input') - shared.gradio['tokens-notebook'] = gr.Textbox(lines=23, label='Tokens', elem_classes=['textbox_logits_notebook', 'add_scrollbar']) + shared.gradio['tokens-notebook'] = gr.Textbox(lines=23, label='Tokens', elem_classes=['textbox_logits_notebook', 'add_scrollbar', 'monospace']) with gr.Row(): shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button') @@ -102,4 +102,4 @@ def create_event_handlers(): ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( logits.get_next_logits, gradio('textbox-notebook', 'interface_state', 'use_samplers-notebook', 'logits-notebook'), gradio('logits-notebook', 'logits-notebook-previous'), show_progress=False) - shared.gradio['get_tokens-notebook'].click(lambda x : str(encode(x)[0].tolist()), gradio('textbox-notebook'), gradio('tokens-notebook'), show_progress=False) + shared.gradio['get_tokens-notebook'].click(get_token_ids, gradio('textbox-notebook'), gradio('tokens-notebook'), show_progress=False)