Source code for czbenchmarks.tasks.integration

import logging
from typing import Set, List

from ..datasets import BaseDataset, DataType
from ..models.types import ModelType
from ..metrics import metrics_registry
from ..metrics.types import MetricResult, MetricType
from .base import BaseTask

logger = logging.getLogger(__name__)


[docs] class BatchIntegrationTask(BaseTask): """Task for evaluating batch integration quality. This task computes metrics to assess how well different batches are integrated in the embedding space while preserving biological signals. Args: label_key: Key to access ground truth cell type labels in metadata batch_key: Key to access batch labels in metadata """ def __init__(self, label_key: str, batch_key: str): self.label_key = label_key self.batch_key = batch_key @property def display_name(self) -> str: """A pretty name to use when displaying task results""" return "batch integration" @property def required_inputs(self) -> Set[DataType]: """Required input data types. Returns: Set of required input DataTypes (metadata with labels) """ return {DataType.METADATA} @property def required_outputs(self) -> Set[DataType]: """Required output data types. Returns: required output types from models this task to run (embedding coordinates) """ return {DataType.EMBEDDING} def _run_task(self, data: BaseDataset, model_type: ModelType): """Runs the batch integration evaluation task. Gets embedding coordinates, batch labels and cell type labels from the dataset for metric computation. Args: data: Dataset containing embedding and labels """ self.embedding = data.get_output(model_type, DataType.EMBEDDING) self.batch_labels = data.get_input(DataType.METADATA)[self.batch_key] self.labels = data.get_input(DataType.METADATA)[self.label_key] def _compute_metrics(self) -> List[MetricResult]: """Computes batch integration quality metrics. Returns: List of MetricResult objects containing entropy per cell and batch-aware silhouette scores """ entropy_per_cell_metric = MetricType.ENTROPY_PER_CELL silhouette_batch_metric = MetricType.BATCH_SILHOUETTE return [ MetricResult( metric_type=entropy_per_cell_metric, value=metrics_registry.compute( entropy_per_cell_metric, X=self.embedding, labels=self.batch_labels, ), ), MetricResult( metric_type=silhouette_batch_metric, value=metrics_registry.compute( silhouette_batch_metric, X=self.embedding, labels=self.labels, batch=self.batch_labels, ), ), ]