Add support for logits processors in extensions (#3029)

Morgan Schweers 2023-07-13 13:22:41 -07:00 committed by GitHub
parent eb823fce96
commit 6d1e911577
2 changed files with 19 additions and 2 deletions

modules/extensions.py

@@ -106,15 +106,23 @@ def _apply_history_modifier_extensions(history):
     return history
 
 
-# Extension functions that override the default tokenizer output - currently only the first one will work
+# Extension functions that override the default tokenizer output - The order of execution is not defined
 def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
     for extension, _ in iterator():
         if hasattr(extension, function_name):
-            return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
+            prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
 
     return prompt, input_ids, input_embeds
 
 
+# Allow extensions to add their own logits processors to the stack being run.
+# Each extension would call `processor_list.append({their LogitsProcessor}())`.
+def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
+    for extension, _ in iterator():
+        if hasattr(extension, function_name):
+            getattr(extension, function_name)(processor_list, input_ids)
+
+
 # Get prompt length in tokens after applying extension functions which override the default tokenizer output
 # currently only the first one will work
 def _apply_custom_tokenized_length(prompt):
@@ -183,6 +191,7 @@ EXTENSION_MAP = {
     "history": _apply_history_modifier_extensions,
     "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
     "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
+    'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
     "input_hijack": _apply_input_hijack,
     "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
     "custom_generate_reply": _apply_custom_generate_reply,

modules/text_generation.py

@@ -8,6 +8,7 @@ import traceback
 import numpy as np
 import torch
 import transformers
+from transformers import LogitsProcessorList
 
 import modules.shared as shared
 from modules.callbacks import (
@@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
     generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
 
+    processor = state.get('logits_processor', LogitsProcessorList([]))
+    # In case folks just pass in a processor by itself.
+    if type(processor) != LogitsProcessorList:
+        processor = LogitsProcessorList([processor])
+    apply_extensions('logits_processor', processor, input_ids)
+    generate_params['logits_processor'] = processor
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq:
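
Downstream, `generate_params['logits_processor']` is forwarded to `model.generate()`, which calls each processor on the logits at every decoding step. A standalone sketch of that contract using a stock transformers processor (the model name and generation settings are illustrative only, not taken from this commit):

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

input_ids = tokenizer('Hello,', return_tensors='pt').input_ids

# Same shape of object that generate_reply_HF assembles above.
processors = LogitsProcessorList([
    # Built-in processor: suppress EOS until at least 20 tokens exist.
    MinLengthLogitsProcessor(20, eos_token_id=tokenizer.eos_token_id),
])

# generate() invokes each processor with (input_ids, scores) per step.
output = model.generate(input_ids, max_new_tokens=30, logits_processor=processors)
print(tokenizer.decode(output[0]))
```

Note that the commit also accepts a bare processor in `state['logits_processor']` and wraps it in a `LogitsProcessorList` automatically, so callers need not build the list themselves.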