import logging
from typing import Annotated, List
from pydantic import Field, field_validator
import pandas as pd
import scipy.sparse as sp
from czbenchmarks.types import ListLike
from ..constants import RANDOM_SEED
from ..metrics.types import MetricResult, MetricType
from .task import PCABaselineInput, Task, TaskInput, TaskOutput
from .types import CellRepresentation
logger = logging.getLogger(__name__)
[docs]
class SequentialOrganizationOutput(TaskOutput):
    """Result container produced by the sequential organization task.

    Unlike a clustering task, no predicted labels are generated here; the
    cell representation is simply carried forward so that the metric step
    can score it.
    """

    # Embedding (cell representation) kept for the metric computation.
    embedding: CellRepresentation
[docs]
class SequentialOrganizationTask(Task):
    """Task for evaluating sequential consistency in embeddings.

    Computes sequential quality metrics for an embedding using time point
    labels, i.e. evaluates how well the embedding preserves the sequential
    organization between cells.
    """

    display_name = "Sequential Organization"
    description = "Evaluate sequential consistency in embeddings using time point labels and k-NN based metrics."
    input_model = SequentialOrganizationTaskInput
    baseline_model = PCABaselineInput

    def __init__(self, *, random_seed: int = RANDOM_SEED):
        """Create the task.

        Args:
            random_seed: Seed forwarded unchanged to the ``Task`` base class.
        """
        super().__init__(random_seed=random_seed)

    def _run_task(
        self,
        cell_representation: CellRepresentation,
        task_input: SequentialOrganizationTaskInput,
    ) -> SequentialOrganizationOutput:
        """Run the sequential evaluation task.

        No transformation is applied: the cell representation is wrapped
        as-is so that ``_compute_metrics`` can consume it.

        Args:
            cell_representation: gene expression data or embedding for task
            task_input: Pydantic model with inputs for the task (not read
                by this step)

        Returns:
            SequentialOrganizationOutput: Pydantic model with embedding data
        """
        return SequentialOrganizationOutput(embedding=cell_representation)

    def _compute_metrics(
        self,
        task_input: SequentialOrganizationTaskInput,
        task_output: SequentialOrganizationOutput,
    ) -> List[MetricResult]:
        """Compute sequential consistency metrics.

        Two metrics are computed, in this order: the silhouette score of the
        embedding against the input labels, and the sequential alignment
        score, which additionally receives ``k``, ``normalize`` and
        ``adaptive_k`` from ``task_input``.

        Args:
            task_input: Pydantic model with inputs for the task
            task_output: Pydantic model with outputs from _run_task

        Returns:
            List of MetricResult objects containing sequential metrics
        """
        # Function-scope import, kept as in the original.
        # NOTE(review): presumably avoids a circular import - confirm.
        from ..metrics import metrics_registry

        embedding = task_output.embedding
        labels = task_input.input_labels

        logger.debug("SequentialOrganizationTask._compute_metrics: Computing metrics")
        # Lazy %-style arguments: nothing is formatted unless DEBUG is enabled.
        # getattr keeps a diagnostic line from raising AttributeError when a
        # container exposes no ``.shape`` (e.g. a plain list of labels -
        # NOTE(review): confirm what ListLike admits).
        logger.debug(
            "SequentialOrganizationTask._compute_metrics: embedding shape=%s, labels shape=%s",
            getattr(embedding, "shape", None),
            getattr(labels, "shape", None),
        )

        # Convert sparse matrix to dense if needed for JAX compatibility in metrics
        if sp.issparse(embedding):
            embedding = embedding.toarray()

        results = []

        # Embedding silhouette score with sequential labels
        silhouette = metrics_registry.compute(
            MetricType.SILHOUETTE_SCORE,
            X=embedding,
            labels=labels,
        )
        results.append(
            MetricResult(
                metric_type=MetricType.SILHOUETTE_SCORE,
                value=silhouette,
                params={},
            )
        )

        # Sequential alignment
        alignment = metrics_registry.compute(
            MetricType.SEQUENTIAL_ALIGNMENT,
            X=embedding,
            labels=labels,
            k=task_input.k,
            normalize=task_input.normalize,
            adaptive_k=task_input.adaptive_k,
        )
        results.append(
            MetricResult(
                metric_type=MetricType.SEQUENTIAL_ALIGNMENT,
                value=alignment,
                params={},
            )
        )

        return results