53 lines
2.4 KiB
Python
Raw Normal View History

import traceback
from importlib import import_module
from pathlib import Path
from typing import Tuple
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
from modules import shared
from modules.logging_colors import logger
def _get_available_pipeline_modules():
    """Return the names of the pipeline packages bundled with this extension.

    A pipeline package is any sub-directory of the ``pipelines`` directory
    located next to this file that contains a ``pipelines.py`` file.

    The names are returned sorted. ``Path.iterdir`` yields entries in an
    arbitrary, filesystem-dependent order, and ``load_pipeline`` queries the
    modules in the order returned here, taking the first match. Sorting
    therefore keeps pipeline priority the same on every machine.
    """
    pipeline_path = Path(__file__).parent / 'pipelines'
    return sorted(
        entry.name
        for entry in pipeline_path.iterdir()
        if entry.is_dir() and (entry / 'pipelines.py').exists()
    )
def _import_pipeline_modules() -> dict:
    """Import the ``pipelines`` submodule of every discovered pipeline package.

    Returns a dict mapping package name to imported module, in discovery
    order. A package whose import raises is skipped after logging a warning
    and its traceback, so one broken pipeline does not disable the others.
    """
    modules = {}
    for name in _get_available_pipeline_modules():
        try:
            modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
        except Exception:
            # Best-effort import. Catch ``Exception`` rather than using a bare
            # ``except`` so KeyboardInterrupt / SystemExit still propagate.
            logger.warning(f'Failed to get multimodal pipelines from {name}')
            logger.warning(traceback.format_exc())
    return modules


def _collect_available_pipelines(pipeline_modules: dict) -> list:
    """Concatenate the ``available_pipelines`` of every module that defines one."""
    available = []
    for module in pipeline_modules.values():
        available += getattr(module, 'available_pipelines', [])
    return available


def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
    """Select and construct the multimodal pipeline to use.

    The lookup depends on whether a pipeline was named explicitly:

    * If ``shared.args.multimodal_pipeline`` is set, each imported pipeline
      module's ``get_pipeline(name, params)`` hook is called in turn.
    * Otherwise each module's ``get_pipeline_from_model_name(model_name,
      params)`` hook is called, using the lower-cased ``shared.args.model``.

    Modules lacking the relevant hook are skipped. The first non-None result
    wins.

    Args:
        params: extension parameters, forwarded unchanged to the module hooks.

    Returns:
        ``(pipeline, module_name)``: the pipeline object and the name of the
        pipeline package that produced it.

    Raises:
        RuntimeError: if no module returned a pipeline. The message lists the
            names collected from every module's ``available_pipelines``.
    """
    pipeline_modules = _import_pipeline_modules()

    requested = shared.args.multimodal_pipeline
    if requested is not None:
        hook_name, key = 'get_pipeline', requested
    else:
        # ``or ''`` covers the case where no model was specified. Without it,
        # this would crash with AttributeError on ``None.lower()`` instead of
        # reaching the descriptive error below.
        hook_name, key = 'get_pipeline_from_model_name', (shared.args.model or '').lower()

    for module_name, module in pipeline_modules.items():
        hook = getattr(module, hook_name, None)
        if hook is None:
            continue
        pipeline = hook(key, params)
        if pipeline is not None:
            return (pipeline, module_name)

    available = _collect_available_pipelines(pipeline_modules)
    if requested is not None:
        log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{requested}", available pipelines are: {available}.'
    else:
        log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
    message = f'{log} Please specify a correct pipeline, or disable the extension'
    logger.critical(message)
    raise RuntimeError(message)