mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add the option to use samplers in the logit viewer
This commit is contained in:
parent
25e5eaa6a6
commit
8545052c9d
@ -237,6 +237,11 @@ audio {
|
|||||||
border-radius: 0.4em;
|
border-radius: 0.4em;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.no-background {
|
||||||
|
background: var(--background-fill-primary) !important;
|
||||||
|
padding: 0px !important;
|
||||||
|
}
|
||||||
|
|
||||||
/*****************************************************/
|
/*****************************************************/
|
||||||
/*************** Chat UI declarations ****************/
|
/*************** Chat UI declarations ****************/
|
||||||
/*****************************************************/
|
/*****************************************************/
|
||||||
|
@ -82,3 +82,12 @@ observer.observe(targetElement, config);
|
|||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
document.getElementById('chat-input').parentNode.style.background = 'transparent';
|
document.getElementById('chat-input').parentNode.style.background = 'transparent';
|
||||||
document.getElementById('chat-input').parentNode.style.border = 'none';
|
document.getElementById('chat-input').parentNode.style.border = 'none';
|
||||||
|
|
||||||
|
//------------------------------------------------
|
||||||
|
// Remove some backgrounds
|
||||||
|
//------------------------------------------------
|
||||||
|
const noBackgroundelements = document.querySelectorAll('.no-background');
|
||||||
|
for(i = 0; i < noBackgroundelements.length; i++) {
|
||||||
|
noBackgroundelements[i].parentNode.style.border = 'none';
|
||||||
|
noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = 'center';
|
||||||
|
}
|
||||||
|
@ -24,6 +24,7 @@ class Stream(transformers.StoppingCriteria):
|
|||||||
def __call__(self, input_ids, scores) -> bool:
|
def __call__(self, input_ids, scores) -> bool:
|
||||||
if self.callback_func is not None:
|
if self.callback_func is not None:
|
||||||
self.callback_func(input_ids[0])
|
self.callback_func(input_ids[0])
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,19 +1,30 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import shared
|
from modules import sampler_hijack, shared
|
||||||
|
from modules.text_generation import generate_reply
|
||||||
|
|
||||||
|
global_scores = None
|
||||||
|
|
||||||
|
|
||||||
def get_next_logits(prompt):
|
def get_next_logits(prompt, state, use_samplers, previous):
|
||||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
|
if use_samplers:
|
||||||
output = shared.model(input_ids=tokens)
|
state['max_new_tokens'] = 1
|
||||||
|
state['auto_max_new_tokens'] = False
|
||||||
|
for _ in generate_reply(prompt, state):
|
||||||
|
pass
|
||||||
|
|
||||||
|
scores = sampler_hijack.global_scores[-1]
|
||||||
|
else:
|
||||||
|
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
|
||||||
|
output = shared.model(input_ids=tokens)
|
||||||
|
scores = output['logits'][-1][-1]
|
||||||
|
|
||||||
scores = output['logits'][-1][-1]
|
|
||||||
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
||||||
|
|
||||||
topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True)
|
topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True)
|
||||||
topk_values = [f"{float(i):.5f}" % i for i in topk_values]
|
topk_values = [f"{float(i):.5f}" for i in topk_values]
|
||||||
|
|
||||||
output = ''
|
output = ''
|
||||||
for row in list(zip(topk_values, shared.tokenizer.convert_ids_to_tokens(topk_indices))):
|
for row in list(zip(topk_values, shared.tokenizer.convert_ids_to_tokens(topk_indices))):
|
||||||
output += f"{row[0]} {row[1]}\n"
|
output += f"{row[0]} - {row[1]}\n"
|
||||||
|
|
||||||
return output
|
return output, previous
|
||||||
|
@ -10,6 +10,8 @@ from transformers.generation.logits_process import (
|
|||||||
TemperatureLogitsWarper
|
TemperatureLogitsWarper
|
||||||
)
|
)
|
||||||
|
|
||||||
|
global_scores = None
|
||||||
|
|
||||||
|
|
||||||
class TailFreeLogitsWarper(LogitsWarper):
|
class TailFreeLogitsWarper(LogitsWarper):
|
||||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
@ -122,6 +124,16 @@ class MirostatLogitsWarper(LogitsWarper):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class SpyLogitsWarper(LogitsWarper):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
global global_scores
|
||||||
|
global_scores = scores
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||||
'''
|
'''
|
||||||
Copied from the transformers library
|
Copied from the transformers library
|
||||||
@ -168,6 +180,7 @@ def get_logits_warper_patch(self, generation_config):
|
|||||||
else:
|
else:
|
||||||
warpers += warpers_to_add
|
warpers += warpers_to_add
|
||||||
|
|
||||||
|
warpers.append(SpyLogitsWarper())
|
||||||
return warpers
|
return warpers
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ def create_ui():
|
|||||||
with gr.Column(scale=5):
|
with gr.Column(scale=5):
|
||||||
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).')
|
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -44,8 +44,15 @@ def create_ui():
|
|||||||
shared.gradio['html-default'] = gr.HTML()
|
shared.gradio['html-default'] = gr.HTML()
|
||||||
|
|
||||||
with gr.Tab('Logits'):
|
with gr.Tab('Logits'):
|
||||||
shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities')
|
with gr.Row():
|
||||||
shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
with gr.Column(scale=10):
|
||||||
|
shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities')
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
shared.gradio['use_samplers-default'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
||||||
|
shared.gradio['logits-default-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
||||||
|
|
||||||
|
|
||||||
def create_event_handlers():
|
def create_event_handlers():
|
||||||
@ -83,5 +90,7 @@ def create_event_handlers():
|
|||||||
lambda x: x + '.txt', gradio('prompt_menu-default'), gradio('delete_filename')).then(
|
lambda x: x + '.txt', gradio('prompt_menu-default'), gradio('delete_filename')).then(
|
||||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||||
|
|
||||||
shared.gradio['textbox-default'].change(lambda x : f"<span>{count_tokens(x)}</span>", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
|
shared.gradio['textbox-default'].change(lambda x: f"<span>{count_tokens(x)}</span>", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
|
||||||
shared.gradio['get_logits-default'].click(logits.get_next_logits, gradio('textbox-default'), gradio('logits-default'), show_progress=False)
|
shared.gradio['get_logits-default'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
logits.get_next_logits, gradio('textbox-default', 'interface_state', 'use_samplers-default', 'logits-default'), gradio('logits-default', 'logits-default-previous'), show_progress=False)
|
||||||
|
@ -30,8 +30,15 @@ def create_ui():
|
|||||||
shared.gradio['html-notebook'] = gr.HTML()
|
shared.gradio['html-notebook'] = gr.HTML()
|
||||||
|
|
||||||
with gr.Tab('Logits'):
|
with gr.Tab('Logits'):
|
||||||
shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities')
|
with gr.Row():
|
||||||
shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits_notebook', 'add_scrollbar'])
|
with gr.Column(scale=10):
|
||||||
|
shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities')
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
shared.gradio['use_samplers-notebook'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
||||||
|
shared.gradio['logits-notebook-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button')
|
shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button')
|
||||||
@ -85,5 +92,7 @@ def create_event_handlers():
|
|||||||
lambda x: x + '.txt', gradio('prompt_menu-notebook'), gradio('delete_filename')).then(
|
lambda x: x + '.txt', gradio('prompt_menu-notebook'), gradio('delete_filename')).then(
|
||||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||||
|
|
||||||
shared.gradio['textbox-notebook'].input(lambda x : f"<span>{count_tokens(x)}</span>", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
|
shared.gradio['textbox-notebook'].input(lambda x: f"<span>{count_tokens(x)}</span>", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
|
||||||
shared.gradio['get_logits-notebook'].click(logits.get_next_logits, gradio('textbox-notebook'), gradio('logits-notebook'))
|
shared.gradio['get_logits-notebook'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
logits.get_next_logits, gradio('textbox-notebook', 'interface_state', 'use_samplers-notebook', 'logits-notebook'), gradio('logits-notebook', 'logits-notebook-previous'), show_progress=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user