From 6d1e9115774a8b0adcbda38786af3856fe57e452 Mon Sep 17 00:00:00 2001
From: Morgan Schweers
Date: Thu, 13 Jul 2023 13:22:41 -0700
Subject: [PATCH] Add support for logits processors in extensions (#3029)

---
 modules/extensions.py      | 13 +++++++++++--
 modules/text_generation.py |  8 ++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/modules/extensions.py b/modules/extensions.py
index 8705101a..faf6cf6d 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -106,15 +106,23 @@ def _apply_history_modifier_extensions(history):
     return history
 
 
-# Extension functions that override the default tokenizer output - currently only the first one will work
+# Extension functions that override the default tokenizer output - The order of execution is not defined
 def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
     for extension, _ in iterator():
         if hasattr(extension, function_name):
-            return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
+            prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
 
     return prompt, input_ids, input_embeds
 
 
+# Allow extensions to add their own logits processors to the stack being run.
+# Each extension would call `processor_list.append({their LogitsProcessor}())`.
+def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
+    for extension, _ in iterator():
+        if hasattr(extension, function_name):
+            getattr(extension, function_name)(processor_list, input_ids)
+
+
 # Get prompt length in tokens after applying extension functions which override the default tokenizer output
 # currently only the first one will work
 def _apply_custom_tokenized_length(prompt):
@@ -183,6 +191,7 @@ EXTENSION_MAP = {
     "history": _apply_history_modifier_extensions,
     "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
     "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
+    'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
     "input_hijack": _apply_input_hijack,
     "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
     "custom_generate_reply": _apply_custom_generate_reply,
diff --git a/modules/text_generation.py b/modules/text_generation.py
index b7f6edf3..566c2f55 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -8,6 +8,7 @@ import traceback
 import numpy as np
 import torch
 import transformers
+from transformers import LogitsProcessorList
 
 import modules.shared as shared
 from modules.callbacks import (
@@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
     generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
 
+    processor = state.get('logits_processor', LogitsProcessorList([]))
+    # In case folks just pass in a processor by itself.
+    if type(processor) != LogitsProcessorList:
+        processor = LogitsProcessorList([processor])
+    apply_extensions('logits_processor', processor, input_ids)
+    generate_params['logits_processor'] = processor
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq: