From 8545052c9d994370b110047e634c4593d02d50f9 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 22 Aug 2023 20:18:16 -0700
Subject: [PATCH] Add the option to use samplers in the logit viewer
---
css/main.css | 5 +++++
js/main.js | 9 +++++++++
modules/callbacks.py | 1 +
modules/logits.py | 29 ++++++++++++++++++++---------
modules/sampler_hijack.py | 13 +++++++++++++
modules/training.py | 2 +-
modules/ui_default.py | 17 +++++++++++++----
modules/ui_notebook.py | 17 +++++++++++++----
8 files changed, 75 insertions(+), 18 deletions(-)
diff --git a/css/main.css b/css/main.css
index 3408375c..405b57e0 100644
--- a/css/main.css
+++ b/css/main.css
@@ -237,6 +237,11 @@ audio {
border-radius: 0.4em;
}
+.no-background {
+ background: var(--background-fill-primary) !important;
+ padding: 0px !important;
+}
+
/*****************************************************/
/*************** Chat UI declarations ****************/
/*****************************************************/
diff --git a/js/main.js b/js/main.js
index 6a27c3b4..e409cc3d 100644
--- a/js/main.js
+++ b/js/main.js
@@ -82,3 +82,12 @@ observer.observe(targetElement, config);
//------------------------------------------------
document.getElementById('chat-input').parentNode.style.background = 'transparent';
document.getElementById('chat-input').parentNode.style.border = 'none';
+
+//------------------------------------------------
+// Remove some backgrounds
+//------------------------------------------------
+const noBackgroundelements = document.querySelectorAll('.no-background');
+for(i = 0; i < noBackgroundelements.length; i++) {
+ noBackgroundelements[i].parentNode.style.border = 'none';
+ noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = 'center';
+}
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 1fa95e47..e29e397d 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -24,6 +24,7 @@ class Stream(transformers.StoppingCriteria):
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
+
return False
diff --git a/modules/logits.py b/modules/logits.py
index 99cb336f..3bfeb6b0 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -1,19 +1,30 @@
import torch
-from modules import shared
+from modules import sampler_hijack, shared
+from modules.text_generation import generate_reply
+
+global_scores = None
-def get_next_logits(prompt):
- tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
- output = shared.model(input_ids=tokens)
+def get_next_logits(prompt, state, use_samplers, previous):
+ if use_samplers:
+ state['max_new_tokens'] = 1
+ state['auto_max_new_tokens'] = False
+ for _ in generate_reply(prompt, state):
+ pass
+
+ scores = sampler_hijack.global_scores[-1]
+ else:
+ tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
+ output = shared.model(input_ids=tokens)
+ scores = output['logits'][-1][-1]
- scores = output['logits'][-1][-1]
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
-
topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True)
- topk_values = [f"{float(i):.5f}" % i for i in topk_values]
+ topk_values = [f"{float(i):.5f}" for i in topk_values]
+
output = ''
for row in list(zip(topk_values, shared.tokenizer.convert_ids_to_tokens(topk_indices))):
- output += f"{row[0]} {row[1]}\n"
+ output += f"{row[0]} - {row[1]}\n"
- return output
+ return output, previous
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index d5ebbb76..0a724f47 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -10,6 +10,8 @@ from transformers.generation.logits_process import (
TemperatureLogitsWarper
)
+global_scores = None
+
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -122,6 +124,16 @@ class MirostatLogitsWarper(LogitsWarper):
return scores
+class SpyLogitsWarper(LogitsWarper):
+ def __init__(self):
+ pass
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ global global_scores
+ global_scores = scores
+ return scores
+
+
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
'''
Copied from the transformers library
@@ -168,6 +180,7 @@ def get_logits_warper_patch(self, generation_config):
else:
warpers += warpers_to_add
+ warpers.append(SpyLogitsWarper())
return warpers
diff --git a/modules/training.py b/modules/training.py
index 7be0d24f..a993f6f0 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -64,7 +64,7 @@ def create_ui():
with gr.Column(scale=5):
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
with gr.Column():
- always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).')
+ always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
with gr.Row():
with gr.Column():
diff --git a/modules/ui_default.py b/modules/ui_default.py
index 5470a6ad..29b9bee5 100644
--- a/modules/ui_default.py
+++ b/modules/ui_default.py
@@ -44,8 +44,15 @@ def create_ui():
shared.gradio['html-default'] = gr.HTML()
with gr.Tab('Logits'):
- shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities')
- shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
+ with gr.Row():
+ with gr.Column(scale=10):
+ shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities')
+ with gr.Column(scale=1):
+ shared.gradio['use_samplers-default'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
+
+ with gr.Row():
+ shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
+ shared.gradio['logits-default-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
def create_event_handlers():
@@ -83,5 +90,7 @@ def create_event_handlers():
lambda x: x + '.txt', gradio('prompt_menu-default'), gradio('delete_filename')).then(
lambda: gr.update(visible=True), None, gradio('file_deleter'))
- shared.gradio['textbox-default'].change(lambda x : f"{count_tokens(x)}", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
- shared.gradio['get_logits-default'].click(logits.get_next_logits, gradio('textbox-default'), gradio('logits-default'), show_progress=False)
+ shared.gradio['textbox-default'].change(lambda x: f"{count_tokens(x)}", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False)
+ shared.gradio['get_logits-default'].click(
+ ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+ logits.get_next_logits, gradio('textbox-default', 'interface_state', 'use_samplers-default', 'logits-default'), gradio('logits-default', 'logits-default-previous'), show_progress=False)
diff --git a/modules/ui_notebook.py b/modules/ui_notebook.py
index 7fbf7a85..9ff0c3fe 100644
--- a/modules/ui_notebook.py
+++ b/modules/ui_notebook.py
@@ -30,8 +30,15 @@ def create_ui():
shared.gradio['html-notebook'] = gr.HTML()
with gr.Tab('Logits'):
- shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities')
- shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits_notebook', 'add_scrollbar'])
+ with gr.Row():
+ with gr.Column(scale=10):
+ shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities')
+ with gr.Column(scale=1):
+ shared.gradio['use_samplers-notebook'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background'])
+
+ with gr.Row():
+ shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar'])
+ shared.gradio['logits-notebook-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits', 'add_scrollbar'])
with gr.Row():
shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button')
@@ -85,5 +92,7 @@ def create_event_handlers():
lambda x: x + '.txt', gradio('prompt_menu-notebook'), gradio('delete_filename')).then(
lambda: gr.update(visible=True), None, gradio('file_deleter'))
- shared.gradio['textbox-notebook'].input(lambda x : f"{count_tokens(x)}", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
- shared.gradio['get_logits-notebook'].click(logits.get_next_logits, gradio('textbox-notebook'), gradio('logits-notebook'))
+ shared.gradio['textbox-notebook'].input(lambda x: f"{count_tokens(x)}", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False)
+ shared.gradio['get_logits-notebook'].click(
+ ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+ logits.get_next_logits, gradio('textbox-notebook', 'interface_state', 'use_samplers-notebook', 'logits-notebook'), gradio('logits-notebook', 'logits-notebook-previous'), show_progress=False)