Source code for czbenchmarks.tasks.sequential

import logging
from typing import Annotated, List
from pydantic import Field, field_validator

import pandas as pd
import scipy.sparse as sp

from czbenchmarks.types import ListLike

from ..constants import RANDOM_SEED
from ..metrics.types import MetricResult, MetricType
from .task import PCABaselineInput, Task, TaskInput, TaskOutput
from .types import CellRepresentation

logger = logging.getLogger(__name__)


class SequentialOrganizationTaskInput(TaskInput):
    """Pydantic model for Sequential Organization inputs."""

    obs: Annotated[
        pd.DataFrame,
        Field(
            description="Cell metadata DataFrame (e.g. the `obs` from an AnnData object)."
        ),
    ]
    input_labels: Annotated[
        ListLike,
        Field(
            description="Ground truth labels for metric calculation (e.g. `obs.cell_type` from an AnnData object)."
        ),
    ]
    k: Annotated[
        int, Field(description="Number of nearest neighbors for k-NN based metrics.")
    ] = 15
    normalize: Annotated[
        bool,
        Field(description="Whether to normalize the embedding for k-NN based metrics."),
    ] = True
    adaptive_k: Annotated[
        bool,
        Field(
            description="Whether to use an adaptive number of nearest neighbors for k-NN based metrics."
        ),
    ] = False

    @field_validator("k")
    @classmethod
    def _validate_k(cls, v: int) -> int:
        if v <= 0:
            raise ValueError("k must be a positive integer.")
        return v

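# Illustrative sketch (not part of the original module): constructing the input
# model. The "time_point" obs column is a hypothetical name, and a pandas
# Series is assumed to satisfy ListLike.
example_obs = pd.DataFrame({"time_point": ["day0", "day0", "day3", "day7"]})
example_input = SequentialOrganizationTaskInput(
    obs=example_obs,
    input_labels=example_obs["time_point"],  # ground-truth sequential labels
    k=15,
    normalize=True,
    adaptive_k=False,
)
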
class SequentialOrganizationOutput(TaskOutput):
    """Output for sequential organization task."""

    # Sequential organization doesn't produce predicted labels like clustering,
    # but we store the embedding for metric computation
    embedding: CellRepresentation

class SequentialOrganizationTask(Task):
    """Task for evaluating sequential consistency in embeddings.

    This task computes sequential quality metrics for embeddings using time
    point labels. Evaluates how well embeddings preserve sequential
    organization between cells.
    """

    display_name = "Sequential Organization"
    description = "Evaluate sequential consistency in embeddings using time point labels and k-NN based metrics."
    input_model = SequentialOrganizationTaskInput
    baseline_model = PCABaselineInput

    def __init__(self, *, random_seed: int = RANDOM_SEED):
        super().__init__(random_seed=random_seed)

    def _run_task(
        self,
        cell_representation: CellRepresentation,
        task_input: SequentialOrganizationTaskInput,
    ) -> SequentialOrganizationOutput:
        """Runs the sequential evaluation task.

        Gets embedding coordinates for metric computation.

        Args:
            cell_representation: gene expression data or embedding for task
            task_input: Pydantic model with inputs for the task

        Returns:
            SequentialOrganizationOutput: Pydantic model with embedding data
        """
        # Store the cell representation (embedding) for metric computation
        return SequentialOrganizationOutput(embedding=cell_representation)

    def _compute_metrics(
        self,
        task_input: SequentialOrganizationTaskInput,
        task_output: SequentialOrganizationOutput,
    ) -> List[MetricResult]:
        """Computes sequential consistency metrics.

        Args:
            task_input: Pydantic model with inputs for the task
            task_output: Pydantic model with outputs from _run_task

        Returns:
            List of MetricResult objects containing sequential metrics
        """
        from ..metrics import metrics_registry

        logger.debug("SequentialOrganizationTask._compute_metrics: Computing metrics")
        logger.debug(
            f"SequentialOrganizationTask._compute_metrics: embedding shape={task_output.embedding.shape}, labels shape={task_input.input_labels.shape}"
        )

        results = []
        embedding = task_output.embedding
        labels = task_input.input_labels

        # Convert sparse matrix to dense if needed for JAX compatibility in metrics
        if sp.issparse(embedding):
            embedding = embedding.toarray()

        # Embedding Silhouette Score with sequential labels
        results.append(
            MetricResult(
                metric_type=MetricType.SILHOUETTE_SCORE,
                value=metrics_registry.compute(
                    MetricType.SILHOUETTE_SCORE,
                    X=embedding,
                    labels=labels,
                ),
                params={},
            )
        )

        # Sequential alignment
        results.append(
            MetricResult(
                metric_type=MetricType.SEQUENTIAL_ALIGNMENT,
                value=metrics_registry.compute(
                    MetricType.SEQUENTIAL_ALIGNMENT,
                    X=embedding,
                    labels=labels,
                    k=task_input.k,
                    normalize=task_input.normalize,
                    adaptive_k=task_input.adaptive_k,
                ),
                params={},
            )
        )

        return results
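
# Usage sketch (illustrative, not part of the original module): how a caller
# might exercise the task end to end. It invokes the protected
# _run_task/_compute_metrics hooks directly for clarity; the Task base class
# presumably exposes a public entrypoint that wraps them. The random embedding
# stands in for a real model's cell representation, and "time_point" is a
# hypothetical obs column name.
import numpy as np
import pandas as pd

from czbenchmarks.tasks.sequential import (
    SequentialOrganizationTask,
    SequentialOrganizationTaskInput,
)

rng = np.random.default_rng(0)
obs = pd.DataFrame({"time_point": ["day0"] * 50 + ["day3"] * 50})
embedding = rng.normal(size=(100, 32))

task = SequentialOrganizationTask()
task_input = SequentialOrganizationTaskInput(
    obs=obs, input_labels=obs["time_point"], k=15
)
output = task._run_task(embedding, task_input)
for metric in task._compute_metrics(task_input, output):
    print(metric.metric_type, metric.value)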