Source code for czbenchmarks.tasks.base

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Union

from ..datasets import BaseDataset, DataType
from ..models.types import ModelType
from ..metrics.types import MetricResult
from .utils import run_standard_scrna_workflow


class BaseTask(ABC):
    """Abstract base class for all benchmark tasks.

    Defines the interface that all tasks must implement. Tasks are responsible for:

    1. Declaring their required input/output data types
    2. Running task-specific computations
    3. Computing evaluation metrics

    Tasks should store any intermediate results needed for metric computation
    as instance variables.
    """

    @property
    @abstractmethod
    def display_name(self) -> str:
        """A pretty name to use when displaying task results."""

    @property
    @abstractmethod
    def required_inputs(self) -> Set[DataType]:
        """Input data types required by this task.

        Returns:
            Set of DataType enums that must be present in the input data
        """

    @property
    @abstractmethod
    def required_outputs(self) -> Set[DataType]:
        """Output data types this task requires from models.

        Returns:
            Set of DataType enums that must be present in the output data
        """

    @property
    def requires_multiple_datasets(self) -> bool:
        """Whether this task requires multiple datasets."""
        return False

    def validate(self, data: BaseDataset):
        """Validate that a dataset provides all required inputs and model outputs.

        Args:
            data: Dataset to validate

        Raises:
            ValueError: If any required inputs or model outputs are missing
        """
        error_msg = []

        missing_inputs = self.required_inputs - set(data.inputs.keys())
        if missing_inputs:
            error_msg.append(f"Missing required inputs: {missing_inputs}")

        # Check if there are any model outputs at all
        if not data.outputs:
            error_msg.append("No model outputs available")
        else:
            for model_type in data.outputs:
                missing_outputs = self.required_outputs - set(
                    data.outputs[model_type].keys()
                )
                if missing_outputs:
                    error_msg.append(
                        "Missing required outputs for model type "
                        f"{model_type.name}: {missing_outputs}"
                    )

        if error_msg:
            raise ValueError(
                f"Data validation failed for {self.__class__.__name__}: "
                f"{' | '.join(error_msg)}"
            )

        data.validate()

    @abstractmethod
    def _run_task(self, data: BaseDataset, model_type: ModelType):
        """Run the task's core computation.

        Should store any intermediate results needed for metric computation
        as instance variables.

        Args:
            data: Dataset containing required input and output data
            model_type: Model type whose outputs should be evaluated

        Returns:
            Modified or unmodified dataset
        """

    @abstractmethod
    def _compute_metrics(self) -> List[MetricResult]:
        """Compute evaluation metrics for the task.

        Returns:
            List of MetricResult objects containing metric values and metadata
        """

    def _run_task_for_dataset(
        self,
        data: Union[BaseDataset, List[BaseDataset]],
        model_types: Optional[List[ModelType]] = None,
    ) -> Dict[ModelType, List[MetricResult]]:
        """Run the task on a dataset or list of datasets and compute metrics per model.

        This method determines which model types to evaluate, validates their
        availability, runs the task implementation for each model type, and
        computes the corresponding metrics.

        Args:
            data: Single dataset or list of datasets containing required
                input and output data
            model_types: Optional list of specific model types to evaluate.
                If None, all model types available in every dataset are used.

        Returns:
            Dictionary mapping each model type to its list of MetricResult objects

        Raises:
            ValueError: If no common model types are found across datasets, or if
                a specified model type is not available in a dataset
        """
        # Dictionary to store metrics for each model type
        all_metrics_per_model = {}

        # Determine which model types to evaluate if not explicitly provided
        if model_types is None:
            if isinstance(data, list):
                # For multiple datasets, find model types available in all datasets
                model_types = set.intersection(
                    *[set(dataset.outputs.keys()) for dataset in data]
                )
                if not model_types:
                    raise ValueError("No common model types found across all datasets")
            else:
                # For a single dataset, use all available model types
                model_types = list(data.outputs.keys())

        # Validate that all requested model types are available in all datasets
        for model_type in model_types:
            if isinstance(data, list):
                for dataset in data:
                    if model_type not in dataset.outputs:
                        raise ValueError(
                            f"Model type {model_type} not found in dataset"
                        )
            else:
                if model_type not in data.outputs:
                    raise ValueError(f"Model type {model_type} not found in dataset")

        # Process each model type
        for model_type in model_types:
            # Run the task implementation for this model
            self._run_task(data, model_type)

            # Compute metrics based on task results
            metrics = self._compute_metrics()

            # Store metrics for this model type
            all_metrics_per_model[model_type] = metrics

        return all_metrics_per_model

    def set_baseline(self, data: BaseDataset, **kwargs):
        """Set a baseline embedding using PCA on gene expression data.

        This method performs standard preprocessing on the raw gene expression
        data and uses PCA for dimensionality reduction. The PCA embedding is
        stored as the BASELINE model output in the dataset, where it can be
        compared against other model embeddings.

        Args:
            data: BaseDataset containing AnnData with gene expression data
            **kwargs: Additional arguments passed to run_standard_scrna_workflow
        """
        # Get the AnnData object from the dataset
        adata = data.get_input(DataType.ANNDATA)

        # Run the standard preprocessing workflow
        adata_baseline = run_standard_scrna_workflow(adata, **kwargs)

        # Use the PCA result as the baseline embedding
        data.set_output(
            ModelType.BASELINE, DataType.EMBEDDING, adata_baseline.obsm["X_pca"]
        )

    def run(
        self,
        data: Union[BaseDataset, List[BaseDataset]],
        model_types: Optional[List[ModelType]] = None,
    ) -> Union[
        Dict[ModelType, List[MetricResult]],
        List[Dict[ModelType, List[MetricResult]]],
    ]:
        """Run the task on input data and compute metrics.

        Args:
            data: Single dataset or list of datasets to evaluate. Must contain
                the required input and output data types.
            model_types: Optional list of specific model types to evaluate.
                If None, all available model types are evaluated.

        Returns:
            For a single dataset: dictionary mapping model types to metric results.
            For multiple datasets: list of such dictionaries, one per dataset.

        Raises:
            ValueError: If data has an invalid type or is missing required fields
            ValueError: If the task requires multiple datasets but a single
                dataset is provided
        """
        # Validate input data type and required fields
        if isinstance(data, BaseDataset):
            self.validate(data)
        elif isinstance(data, list) and all(isinstance(d, BaseDataset) for d in data):
            for d in data:
                self.validate(d)
        else:
            raise ValueError(f"Invalid data type: {type(data)}")

        # Check whether the task requires multiple datasets
        if self.requires_multiple_datasets and not isinstance(data, list):
            raise ValueError("This task requires a list of datasets")

        # Handle single vs. multiple datasets
        if isinstance(data, list) and not self.requires_multiple_datasets:
            # Process each dataset individually
            all_metrics = []
            for d in data:
                all_metrics.append(self._run_task_for_dataset(d, model_types))
            return all_metrics
        else:
            # Process a single dataset, or multiple datasets as required by the task
            return self._run_task_for_dataset(data, model_types)
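
The following is a minimal, illustrative sketch of how a concrete task might subclass BaseTask, assuming a toy metric. It is not part of the czbenchmarks source: the task name, the direct access to data.outputs, and the MetricResult constructor arguments shown here are assumptions made only to demonstrate the interface above.

# Hypothetical example (not part of czbenchmarks.tasks.base): a toy task that
# reports the mean L2 norm of a model's embedding.
from typing import List, Set

import numpy as np

from czbenchmarks.datasets import BaseDataset, DataType
from czbenchmarks.models.types import ModelType
from czbenchmarks.metrics.types import MetricResult
from czbenchmarks.tasks.base import BaseTask


class EmbeddingNormTask(BaseTask):
    """Toy task that scores the magnitude of a model's embedding (illustrative only)."""

    @property
    def display_name(self) -> str:
        return "Embedding norm"

    @property
    def required_inputs(self) -> Set[DataType]:
        return {DataType.ANNDATA}

    @property
    def required_outputs(self) -> Set[DataType]:
        return {DataType.EMBEDDING}

    def _run_task(self, data: BaseDataset, model_type: ModelType):
        # Store intermediate results as instance variables for _compute_metrics.
        # Accessing data.outputs[model_type][DataType.EMBEDDING] directly mirrors
        # the structure used in BaseTask.validate; a dedicated accessor may exist.
        embedding = data.outputs[model_type][DataType.EMBEDDING]
        self._mean_norm = float(np.linalg.norm(embedding, axis=1).mean())

    def _compute_metrics(self) -> List[MetricResult]:
        # The MetricResult arguments below are assumed for illustration; consult
        # czbenchmarks.metrics.types for the actual constructor.
        return [MetricResult(metric_type="mean_embedding_norm", value=self._mean_norm)]

With such a subclass, calling EmbeddingNormTask().run(dataset) validates the dataset, runs _run_task once per available model type, and returns a dictionary mapping each ModelType to its list of MetricResult objects.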