from typing import Annotated, List
import logging
import numpy as np
from pydantic import Field, field_validator
from czbenchmarks.datasets.types import Organism
from ...constants import RANDOM_SEED
from ...metrics import metrics_registry
from ...metrics.types import MetricResult, MetricType
from ...tasks.types import CellRepresentation
from ...types import ListLike
from ..task import NoBaselineInput, Task, TaskInput, TaskOutput
logger = logging.getLogger(__name__)
[docs]
class CrossSpeciesIntegrationOutput(TaskOutput):
"""Output for cross-species integration task."""
cell_representation: CellRepresentation
labels: ListLike
species: ListLike
[docs]
class CrossSpeciesIntegrationTask(Task):
"""Task for evaluating cross-species integration quality.
This task computes metrics to assess how well different species' data are integrated
in the embedding space while preserving biological signals. It operates on multiple
datasets from different species.
"""
display_name = "Cross-species Integration"
description = (
"Evaluate cross-species integration quality using various integration metrics."
)
input_model = CrossSpeciesIntegrationTaskInput
baseline_model = NoBaselineInput
def __init__(self, *, random_seed: int = RANDOM_SEED):
super().__init__(random_seed=random_seed)
self.requires_multiple_datasets = True
def _run_task(
self,
cell_representation: CellRepresentation,
task_input: CrossSpeciesIntegrationTaskInput,
) -> CrossSpeciesIntegrationOutput:
"""Runs the cross-species integration evaluation task.
Gets embedding coordinates and labels from multiple datasets and combines them
for metric computation.
Args:
cell_representation: list of cell representations for the task
task_input: Pydantic model with inputs for the task
Returns:
CrossSpeciesIntegrationOutput: Pydantic model with combined data and labels
"""
logger.debug(
f"CrossSpeciesIntegrationTask._run_task: cell_representation type={type(cell_representation)}, "
f"n_datasets={len(cell_representation) if isinstance(cell_representation, list) else 1}"
)
# FIXME BYODATASETdatasets should be concatenated to align along genes?
# This operation is safe because requires_multiple_datasets is True
cell_representation = np.vstack(cell_representation)
# FIXME BYODATASET move this into validation
if len(set(task_input.organism_list)) < 2:
raise AssertionError(
"At least two organisms are required for cross-species integration "
f"but got {len(set(task_input.organism_list))} : {set(task_input.organism_list)}"
)
species = np.concatenate(
[
[
str(organism),
]
* len(label)
for organism, label in zip(task_input.organism_list, task_input.labels)
]
)
labels = np.concatenate(task_input.labels)
if (len(cell_representation) != len(species)) or (len(species) != len(labels)):
raise AssertionError(
"Cell representation, species, and labels must have the same shape"
)
return CrossSpeciesIntegrationOutput(
cell_representation=cell_representation,
labels=labels,
species=species,
)
def _compute_metrics(
self,
_: CrossSpeciesIntegrationTaskInput,
task_output: CrossSpeciesIntegrationOutput,
) -> List[MetricResult]:
"""Computes batch integration quality metrics.
Args:
_: (unused) Pydantic model with input for the task
task_output: Pydantic model with outputs from _run_task
Returns:
List of MetricResult objects containing entropy per cell and
batch-aware silhouette scores
"""
entropy_per_cell_metric = MetricType.ENTROPY_PER_CELL
silhouette_batch_metric = MetricType.BATCH_SILHOUETTE
cell_representation = task_output.cell_representation
labels = task_output.labels
species = task_output.species
return [
MetricResult(
metric_type=entropy_per_cell_metric,
value=metrics_registry.compute(
entropy_per_cell_metric,
X=cell_representation,
labels=species,
random_seed=self.random_seed,
),
),
MetricResult(
metric_type=silhouette_batch_metric,
value=metrics_registry.compute(
silhouette_batch_metric,
X=cell_representation,
labels=labels,
batch=species,
),
),
]
[docs]
def compute_baseline(
self,
expression_data: CellRepresentation,
baseline_input: NoBaselineInput = None,
):
"""Set a baseline embedding for cross-species integration.
Not implemented as standard preprocessing is not applicable across species.
"""
raise NotImplementedError(
"Baseline not implemented for cross-species integration"
)