Source code for czbenchmarks.tasks.single_cell.cross_species_integration

from typing import Annotated, List
import logging

import numpy as np
from pydantic import Field, field_validator

from czbenchmarks.datasets.types import Organism

from ...constants import RANDOM_SEED
from ...metrics import metrics_registry
from ...metrics.types import MetricResult, MetricType
from ...tasks.types import CellRepresentation
from ...types import ListLike
from ..task import NoBaselineInput, Task, TaskInput, TaskOutput

# Module-level logger named after this module (used for debug output in the task below).
logger = logging.getLogger(__name__)


class CrossSpeciesIntegrationTaskInput(TaskInput):
    """Inputs consumed by the cross-species integration task.

    Both lists are indexed by dataset: ``labels[i]`` holds the ground-truth
    labels of dataset ``i`` and ``organism_list[i]`` names the organism that
    dataset comes from.
    """

    # Per-dataset ground-truth labels (e.g. cell types).
    labels: Annotated[
        List[ListLike],
        Field(
            description="List of ground truth labels for each species dataset (e.g., cell types)."
        ),
    ]
    # Per-dataset organism, in the same order as ``labels``.
    organism_list: Annotated[
        List[Organism],
        Field(
            description="List of organisms corresponding to each dataset for cross-species evaluation."
        ),
    ]

    @field_validator("organism_list")
    @classmethod
    def _validate_organism_list(cls, value: List[Organism]) -> List[Organism]:
        """Accept ``organism_list`` only when it is a Python list.

        Raises:
            ValueError: if ``value`` is not a ``list``.
        """
        if isinstance(value, list):
            return value
        raise ValueError("organism_list must be a list of organisms.")
class CrossSpeciesIntegrationOutput(TaskOutput):
    """Result of the cross-species integration task.

    Holds the stacked embeddings of all datasets together with per-cell
    ground-truth labels and per-cell species names.
    """

    # Embeddings of every dataset stacked into one representation.
    cell_representation: CellRepresentation
    # Ground-truth label for each cell in ``cell_representation``.
    labels: ListLike
    # Organism name (string) for each cell in ``cell_representation``.
    species: ListLike
class CrossSpeciesIntegrationTask(Task):
    """Task for evaluating cross-species integration quality.

    This task computes metrics to assess how well different species' data are
    integrated in the embedding space while preserving biological signals. It
    operates on multiple datasets from different species.
    """

    display_name = "Cross-species Integration"
    description = (
        "Evaluate cross-species integration quality using various integration metrics."
    )
    input_model = CrossSpeciesIntegrationTaskInput
    baseline_model = NoBaselineInput

    def __init__(self, *, random_seed: int = RANDOM_SEED):
        """Create the task.

        Args:
            random_seed: seed forwarded to the base ``Task`` and later passed to
                seeded metrics.
        """
        super().__init__(random_seed=random_seed)
        # Declares that ``cell_representation`` arrives as one entry per
        # dataset; _run_task relies on this when stacking. Presumably read by
        # the base ``Task`` as well.
        self.requires_multiple_datasets = True

    def _run_task(
        self,
        cell_representation: CellRepresentation,
        task_input: CrossSpeciesIntegrationTaskInput,
    ) -> CrossSpeciesIntegrationOutput:
        """Runs the cross-species integration evaluation task.

        Stacks the per-dataset embeddings into one matrix and builds matching
        per-cell ``labels`` and ``species`` arrays for metric computation.

        Args:
            cell_representation: list of cell representations for the task,
                one per dataset, in the same order as ``task_input.labels`` and
                ``task_input.organism_list``
            task_input: Pydantic model with inputs for the task

        Returns:
            CrossSpeciesIntegrationOutput: Pydantic model with combined data and labels

        Raises:
            AssertionError: if fewer than two distinct organisms are given; if
                the numbers of organisms, label sets and datasets disagree; if a
                dataset's row count differs from its number of labels; or if the
                stacked representation, species and labels differ in length.
        """
        is_list = isinstance(cell_representation, list)
        logger.debug(
            "CrossSpeciesIntegrationTask._run_task: cell_representation type=%s, "
            "n_datasets=%d",
            type(cell_representation),
            len(cell_representation) if is_list else 1,
        )

        organism_list = task_input.organism_list
        label_sets = task_input.labels

        # Validate cheaply before the (potentially large) stacking below.
        # FIXME BYODATASET move this into validation
        distinct_organisms = set(organism_list)
        if len(distinct_organisms) < 2:
            raise AssertionError(
                "At least two organisms are required for cross-species integration "
                f"but got {len(distinct_organisms)} : {distinct_organisms}"
            )
        # Without this, pairing organisms with label sets would silently drop
        # trailing entries, mislabelling species or ignoring datasets.
        if len(organism_list) != len(label_sets):
            raise AssertionError(
                "organism_list and labels must have one entry per dataset but got "
                f"{len(organism_list)} organisms and {len(label_sets)} label sets"
            )

        # FIXME BYODATASET datasets should be concatenated to align along genes?
        # Stacking is safe because requires_multiple_datasets is True.
        if is_list:
            if len(cell_representation) != len(label_sets):
                raise AssertionError(
                    f"Expected one cell representation per dataset ({len(label_sets)}) "
                    f"but got {len(cell_representation)}"
                )
            # atleast_2d mirrors how np.vstack interprets each entry, so the
            # row counts below are exactly the rows each dataset contributes.
            blocks = [np.atleast_2d(rep) for rep in cell_representation]
            for idx, (block, dataset_labels) in enumerate(zip(blocks, label_sets)):
                if block.shape[0] != len(dataset_labels):
                    raise AssertionError(
                        f"Dataset {idx} has {block.shape[0]} cells but "
                        f"{len(dataset_labels)} labels"
                    )
            stacked = np.vstack(blocks)
        else:
            stacked = np.vstack(cell_representation)

        # One species name per cell: each organism repeated once per label of
        # its dataset (np.repeat also handles empty label sets cleanly).
        species = np.repeat(
            [str(organism) for organism in organism_list],
            [len(dataset_labels) for dataset_labels in label_sets],
        )
        labels = np.concatenate(label_sets)

        if len(stacked) != len(species) or len(species) != len(labels):
            raise AssertionError(
                "Cell representation, species, and labels must have the same shape"
            )

        return CrossSpeciesIntegrationOutput(
            cell_representation=stacked,
            labels=labels,
            species=species,
        )

    def _compute_metrics(
        self,
        _: CrossSpeciesIntegrationTaskInput,
        task_output: CrossSpeciesIntegrationOutput,
    ) -> List[MetricResult]:
        """Computes batch integration quality metrics.

        Args:
            _: (unused) Pydantic model with input for the task
            task_output: Pydantic model with outputs from _run_task

        Returns:
            List of MetricResult objects containing entropy per cell and
            batch-aware silhouette scores
        """
        entropy_per_cell_metric = MetricType.ENTROPY_PER_CELL
        silhouette_batch_metric = MetricType.BATCH_SILHOUETTE
        cell_representation = task_output.cell_representation
        labels = task_output.labels
        species = task_output.species

        return [
            MetricResult(
                metric_type=entropy_per_cell_metric,
                value=metrics_registry.compute(
                    entropy_per_cell_metric,
                    X=cell_representation,
                    # Entropy is computed over species, not cell-type labels.
                    labels=species,
                    random_seed=self.random_seed,
                ),
            ),
            MetricResult(
                metric_type=silhouette_batch_metric,
                value=metrics_registry.compute(
                    silhouette_batch_metric,
                    X=cell_representation,
                    labels=labels,
                    # Species plays the role of the batch variable.
                    batch=species,
                ),
            ),
        ]

    def compute_baseline(
        self,
        expression_data: CellRepresentation,
        baseline_input: NoBaselineInput = None,
    ):
        """Set a baseline embedding for cross-species integration.

        Not implemented as standard preprocessing is not applicable across species.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Baseline not implemented for cross-species integration"
        )