mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add the option to use samplers in the logit viewer
This commit is contained in:
parent
25e5eaa6a6
commit
8545052c9d
@ -237,6 +237,11 @@ audio {
|
||||
border-radius: 0.4em;
|
||||
}
|
||||
|
||||
.no-background {
|
||||
background: var(--background-fill-primary) !important;
|
||||
padding: 0px !important;
|
||||
}
|
||||
|
||||
/*****************************************************/
|
||||
/*************** Chat UI declarations ****************/
|
||||
/*****************************************************/
|
||||
|
@ -82,3 +82,12 @@ observer.observe(targetElement, config);
|
||||
//------------------------------------------------
|
||||
document.getElementById('chat-input').parentNode.style.background = 'transparent';
|
||||
document.getElementById('chat-input').parentNode.style.border = 'none';
|
||||
|
||||
//------------------------------------------------
|
||||
// Remove some backgrounds
|
||||
//------------------------------------------------
|
||||
const noBackgroundelements = document.querySelectorAll('.no-background');
|
||||
for(i = 0; i < noBackgroundelements.length; i++) {
|
||||
noBackgroundelements[i].parentNode.style.border = 'none';
|
||||
noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = 'center';
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ class Stream(transformers.StoppingCriteria):
|
||||
def __call__(self, input_ids, scores) -> bool:
|
||||
if self.callback_func is not None:
|
||||
self.callback_func(input_ids[0])
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
@ -1,19 +1,30 @@
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules import sampler_hijack, shared
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
global_scores = None
|
||||
|
||||
|
||||
def get_next_logits(prompt):
|
||||
def get_next_logits(prompt, state, use_samplers, previous):
|
||||
if use_samplers:
|
||||
state['max_new_tokens'] = 1
|
||||
state['auto_max_new_tokens'] = False
|
||||
for _ in generate_reply(prompt, state):
|
||||
pass
|
||||
|
||||
scores = sampler_hijack.global_scores[-1]
|
||||
else:
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
|
||||
output = shared.model(input_ids=tokens)
|
||||
|
||||
scores = output['logits'][-1][-1]
|
||||
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
||||
|
||||
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
||||
topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True)
|
||||
topk_values = [f"{float(i):.5f}" % i for i in topk_values]
|
||||
topk_values = [f"{float(i):.5f}" for i in topk_values]
|
||||
|
||||
output = ''
|
||||
for row in list(zip(topk_values, shared.tokenizer.convert_ids_to_tokens(topk_indices))):
|
||||
output += f"{row[0]} {row[1]}\n"
|
||||
output += f"{row[0]} - {row[1]}\n"
|
||||
|
||||
return output
|
||||
return output, previous
|
||||
|
@ -10,6 +10,8 @@ from transformers.generation.logits_process import (
|
||||
TemperatureLogitsWarper
|
||||
)
|
||||
|
||||
global_scores = None
|
||||
|
||||
|
||||
class TailFreeLogitsWarper(LogitsWarper):
|
||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
@ -122,6 +124,16 @@ class MirostatLogitsWarper(LogitsWarper):
|
||||
return scores
|
||||
|
||||
|
||||
class SpyLogitsWarper(LogitsWarper):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
global global_scores
|
||||
global_scores = scores
|
||||
return scores
|
||||
|
||||
|
||||
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
'''
|
||||
Copied from the transformers library
|
||||
@ -168,6 +180,7 @@ def get_logits_warper_patch(self, generation_config):
|
||||
else:
|
||||
warpers += warpers_to_add
|
||||
|
||||
warpers.append(SpyLogitsWarper())
|
||||
return warpers
|
||||
|
||||
|
||||
|
@ -64,7 +64,7 @@ def create_ui():
|
||||
with gr.Column(scale=5):
|
||||
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
||||
with gr.Column():
|
||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).')
|
||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
|
@ -44,8 +44,15 @@ def create_ui():
|
||||
shared.gradio['html-default'] = gr.HTML()
|
||||
|
||||
with gr.Tab('Logits'):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities')
|
||||
with gr.Column(scale=1):
|
||||
shared.gradio['use_samplers-default'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
||||
shared.gradio['logits-default-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
@ -83,5 +90,7 @@ def create_event_handlers():
|
||||
lambda x: x + '.txt', gradio('prompt_menu-default'), gradio('delete_filename')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||
|
||||
shared.gradio['textbox-default'].change(lambda x : f"<span>{count_tokens(x)}</span>", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
|
||||
shared.gradio['get_logits-default'].click(logits.get_next_logits, gradio('textbox-default'), gradio('logits-default'), show_progress=False)
|
||||
shared.gradio['textbox-default'].change(lambda x: f"<span>{count_tokens(x)}</span>", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
|
||||
shared.gradio['get_logits-default'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
logits.get_next_logits, gradio('textbox-default', 'interface_state', 'use_samplers-default', 'logits-default'), gradio('logits-default', 'logits-default-previous'), show_progress=False)
|
||||
|
@ -30,8 +30,15 @@ def create_ui():
|
||||
shared.gradio['html-notebook'] = gr.HTML()
|
||||
|
||||
with gr.Tab('Logits'):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities')
|
||||
shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits_notebook', 'add_scrollbar'])
|
||||
with gr.Column(scale=1):
|
||||
shared.gradio['use_samplers-notebook'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
||||
shared.gradio['logits-notebook-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button')
|
||||
@ -85,5 +92,7 @@ def create_event_handlers():
|
||||
lambda x: x + '.txt', gradio('prompt_menu-notebook'), gradio('delete_filename')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||
|
||||
shared.gradio['textbox-notebook'].input(lambda x : f"<span>{count_tokens(x)}</span>", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
|
||||
shared.gradio['get_logits-notebook'].click(logits.get_next_logits, gradio('textbox-notebook'), gradio('logits-notebook'))
|
||||
shared.gradio['textbox-notebook'].input(lambda x: f"<span>{count_tokens(x)}</span>", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
|
||||
shared.gradio['get_logits-notebook'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
logits.get_next_logits, gradio('textbox-notebook', 'interface_state', 'use_samplers-notebook', 'logits-notebook'), gradio('logits-notebook', 'logits-notebook-previous'), show_progress=False)
|
||||
|
Loading…
Reference in New Issue
Block a user