mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add support for logits processors in extensions (#3029)
This commit is contained in:
parent
eb823fce96
commit
6d1e911577
@ -106,15 +106,23 @@ def _apply_history_modifier_extensions(history):
|
|||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
# Extension functions that override the default tokenizer output - currently only the first one will work
|
# Extension functions that override the default tokenizer output - The order of execution is not defined
|
||||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||||
for extension, _ in iterator():
|
for extension, _ in iterator():
|
||||||
if hasattr(extension, function_name):
|
if hasattr(extension, function_name):
|
||||||
return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||||
|
|
||||||
return prompt, input_ids, input_embeds
|
return prompt, input_ids, input_embeds
|
||||||
|
|
||||||
|
|
||||||
|
# Allow extensions to add their own logits processors to the stack being run.
|
||||||
|
# Each extension would call `processor_list.append({their LogitsProcessor}())`.
|
||||||
|
def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
|
||||||
|
for extension, _ in iterator():
|
||||||
|
if hasattr(extension, function_name):
|
||||||
|
getattr(extension, function_name)(processor_list, input_ids)
|
||||||
|
|
||||||
|
|
||||||
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
|
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
|
||||||
# currently only the first one will work
|
# currently only the first one will work
|
||||||
def _apply_custom_tokenized_length(prompt):
|
def _apply_custom_tokenized_length(prompt):
|
||||||
@ -183,6 +191,7 @@ EXTENSION_MAP = {
|
|||||||
"history": _apply_history_modifier_extensions,
|
"history": _apply_history_modifier_extensions,
|
||||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||||
|
'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
|
||||||
"input_hijack": _apply_input_hijack,
|
"input_hijack": _apply_input_hijack,
|
||||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
|
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
|
||||||
"custom_generate_reply": _apply_custom_generate_reply,
|
"custom_generate_reply": _apply_custom_generate_reply,
|
||||||
|
@ -8,6 +8,7 @@ import traceback
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
from transformers import LogitsProcessorList
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.callbacks import (
|
from modules.callbacks import (
|
||||||
@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
||||||
|
|
||||||
|
processor = state.get('logits_processor', LogitsProcessorList([]))
|
||||||
|
# In case folks just pass in a processor by itself.
|
||||||
|
if type(processor) != LogitsProcessorList:
|
||||||
|
processor = LogitsProcessorList([processor])
|
||||||
|
apply_extensions('logits_processor', processor, input_ids)
|
||||||
|
generate_params['logits_processor'] = processor
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
if not is_chat and not shared.is_seq2seq:
|
if not is_chat and not shared.is_seq2seq:
|
||||||
|
Loading…
Reference in New Issue
Block a user