mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 13:58:56 +01:00
Add support for logits processors in extensions (#3029)
This commit is contained in:
parent
eb823fce96
commit
6d1e911577
@ -106,15 +106,23 @@ def _apply_history_modifier_extensions(history):
|
||||
return history
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output - currently only the first one will work
|
||||
# Extension functions that override the default tokenizer output - The order of execution is not defined
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
# Allow extensions to add their own logits processors to the stack being run.
|
||||
# Each extension would call `processor_list.append({their LogitsProcessor}())`.
|
||||
def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
getattr(extension, function_name)(processor_list, input_ids)
|
||||
|
||||
|
||||
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
|
||||
# currently only the first one will work
|
||||
def _apply_custom_tokenized_length(prompt):
|
||||
@ -183,6 +191,7 @@ EXTENSION_MAP = {
|
||||
"history": _apply_history_modifier_extensions,
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
|
||||
"custom_generate_reply": _apply_custom_generate_reply,
|
||||
|
@ -8,6 +8,7 @@ import traceback
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import LogitsProcessorList
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.callbacks import (
|
||||
@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
||||
|
||||
processor = state.get('logits_processor', LogitsProcessorList([]))
|
||||
# In case folks just pass in a processor by itself.
|
||||
if type(processor) != LogitsProcessorList:
|
||||
processor = LogitsProcessorList([processor])
|
||||
apply_extensions('logits_processor', processor, input_ids)
|
||||
generate_params['logits_processor'] = processor
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not is_chat and not shared.is_seq2seq:
|
||||
|
Loading…
Reference in New Issue
Block a user