Source code for czbenchmarks.models.implementations.base_model_implementation

import glob
import logging
import os
import pathlib
from abc import ABC, abstractmethod
from typing import List, Optional, Any
from argparse import ArgumentParser, Namespace

from ...constants import (
    INPUT_DATA_PATH_DOCKER,
    MODEL_WEIGHTS_PATH_DOCKER,
    OUTPUT_DATA_PATH_DOCKER,
    get_base_name,
    get_numbered_path,
)
from ...datasets import BaseDataset
from ..validators.base_model_validator import BaseModelValidator

# Configure logging to output to stdout
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    force=True,  # This ensures the configuration is applied
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class BaseModelImplementation(BaseModelValidator, ABC):
    """Abstract base class for model implementations.

    Handles common model operations like:
    - Loading datasets
    - Downloading model weights
    - Running inference
    - Saving results
    """

    datasets: List[BaseDataset]
    """List of datasets to process"""

    model_weights_dir: str
    """Directory containing model weights"""

    args: Namespace
    """Parsed command-line arguments"""

    def __init__(self, **user_kwargs):
        """Initialize the implementation and resolve its arguments.

        User-supplied keyword arguments take precedence over parsed CLI
        arguments, so values can be overridden programmatically while the
        command line remains usable for defaults.
        """
        super().__init__()
        parsed_cli = vars(self.parse_args())
        merged = self._merge_arguments(cli_args=parsed_cli, user_args=user_kwargs)
        self.args = Namespace(**merged)
[docs] @abstractmethod def get_model_weights_subdir(self, dataset: BaseDataset) -> str: """Get subdirectory for model variant weights. Args: dataset: Dataset being processed Returns: Subdirectory path or empty string if not needed """
@abstractmethod def _download_model_weights(self, dataset: BaseDataset): """Download model weights if needed. Args: dataset: Dataset being processed """
[docs] def download_model_weights(self, dataset: BaseDataset) -> None: """Download and verify model weights. Args: dataset: Dataset being processed """ model_weights_subdir = self.get_model_weights_subdir(dataset) if model_weights_subdir != "": self.model_weights_dir = ( f"{MODEL_WEIGHTS_PATH_DOCKER}/{model_weights_subdir}" ) else: self.model_weights_dir = MODEL_WEIGHTS_PATH_DOCKER if not os.path.exists(self.model_weights_dir) or not any( os.listdir(self.model_weights_dir) ): logger.info("Downloading model weights...") self._download_model_weights(dataset) logger.info("Model weights downloaded successfully") else: logger.info("Model weights already downloaded...")
[docs] @abstractmethod def run_model(self, dataset: BaseDataset) -> None: """Implement model-specific inference logic"""
[docs] def parse_args(self) -> Namespace: """Centralized argument parsing using subclass-defined parser""" parser = self.create_parser() args = Namespace() if parser: args = parser.parse_known_args()[0] return args
def _merge_arguments( self, cli_args: dict[str, Any], user_args: dict[str, Any] ) -> dict[str, Any]: """Merge arguments with user input as priority""" return {**cli_args, **user_args}
[docs] def create_parser(self) -> ArgumentParser | None: """Subclasses implement to define their CLI arguments""" return None
[docs] def run(self, datasets: Optional[BaseDataset | List[BaseDataset]] = None): """Run the full model pipeline. 1. Load and validate all datasets 2. Download model weights 3. Run inference 4. Save results Args: datasets: List of datasets to process Raises: FileNotFoundError: If no input datasets found """ if datasets: serialize_datasets = False self.datasets = datasets if isinstance(datasets, list) else [datasets] else: serialize_datasets = True # Find all input datasets input_dir = pathlib.Path(INPUT_DATA_PATH_DOCKER).parent base_pattern = get_base_name(INPUT_DATA_PATH_DOCKER) input_files = sorted(glob.glob(os.path.join(input_dir, base_pattern))) if not input_files: raise FileNotFoundError("No input datasets found") # Ensure base file comes first if it exists if INPUT_DATA_PATH_DOCKER in input_files: input_files.remove(INPUT_DATA_PATH_DOCKER) input_files = [INPUT_DATA_PATH_DOCKER] + input_files # Load all datasets self.datasets = [ self.dataset_type.deserialize(input_file) for input_file in input_files ] logger.info("Loading data...") for dataset in self.datasets: dataset.load_data() logger.info("Data loaded successfully") logger.info("Validating data...") for dataset in self.datasets: dataset.validate() self.validate_dataset(dataset) logger.info("Data validated successfully") for dataset in self.datasets: self.download_model_weights(dataset) logger.info("Running model...") for dataset in self.datasets: self.run_model(dataset) logger.info("Model ran successfully") if serialize_datasets: # Unload and serialize all datasets for i, dataset in enumerate(self.datasets): dataset.unload_data() output_path = get_numbered_path(OUTPUT_DATA_PATH_DOCKER, i) dataset.serialize(output_path)