from __future__ import annotations
import inspect
import typing
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type, Union
import anndata as ad
from pydantic import BaseModel, ValidationError
from pydantic.fields import PydanticUndefined
from ..constants import RANDOM_SEED
from ..metrics.types import MetricResult
from .types import CellRepresentation
from .utils import run_standard_scrna_workflow
class TaskOutput(BaseModel):
"""Base class for task outputs."""
model_config = {"arbitrary_types_allowed": True}
class TaskParameter(BaseModel):
"""Schema for a single, discoverable parameter."""
type: Any
stringified_type: str
default: Any = None
required: bool
class TaskInfo(BaseModel):
"""Schema for all discoverable information about a single benchmark task."""
name: str
display_name: str
description: str
task_params: Dict[str, TaskParameter]
baseline_params: Dict[str, TaskParameter]
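# Illustrative sketch (assumed values, not produced by this module): what an
# introspected entry might look like for a hypothetical "clustering" task whose
# input model declares an ``n_clusters: int = 10`` field.
#
#   TaskInfo(
#       name="ClusteringTask",
#       display_name="Clustering",
#       description="Cluster cells and evaluate the resulting labels.",
#       task_params={
#           "n_clusters": TaskParameter(
#               type=int,
#               stringified_type="<class 'int'>",
#               default=10,
#               required=False,
#           ),
#       },
#       baseline_params={},
#   )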
class TaskRegistry:
"""A registry that is populated automatically as Task subclasses are defined."""
def __init__(self):
self._registry: Dict[str, Type["Task"]] = {}
self._info: Dict[str, TaskInfo] = {}
def register_task(self, task_class: type[Task]):
"""Registers a task class and introspects it to gather metadata."""
if inspect.isabstract(task_class) or not hasattr(task_class, "display_name"):
print(
f"Error: Task class {task_class.__name__} missing display_name or is abstract."
)
return
key = (
getattr(task_class, "display_name", task_class.__name__)
.lower()
.replace(" ", "_")
)
self._registry[key] = task_class
self._info[key] = self._introspect_task(task_class)
def _stringify_type(self, annotation: Any) -> str:
"""Return a string representation of a type annotation."""
try:
return str(annotation).replace("typing.", "")
except Exception:
return str(annotation)
def _introspect_task(self, task_class: type[Task]) -> TaskInfo:
"""Extracts parameter and metric information from a task class."""
try:
# 1. Get Task Parameters from the associated Pydantic input model
task_params = {}
if hasattr(task_class, "input_model") and issubclass(
task_class.input_model, BaseModel
):
for (
field_name,
field_info,
) in task_class.input_model.model_fields.items():
type_info = self._extract_type_info(
field_info.annotation, field_name
)
type_str = self._stringify_type(type_info)
task_params[field_name] = TaskParameter(
type=type_info,
stringified_type=type_str,
default=field_info.default
if field_info.default is not PydanticUndefined
else None,
required=field_info.is_required(),
)
# 2. Get Baseline Parameters from the compute_baseline method signature
baseline_params = {}
try:
hints = typing.get_type_hints(
task_class.compute_baseline, include_extras=True
)
sig = inspect.signature(task_class.compute_baseline)
for param in list(sig.parameters.values())[1:]: # Skip 'self'
if param.name in {
"kwargs",
"cell_representation",
"expression_data",
}:
continue
type_info = hints.get(param.name, Any)
type_str = self._stringify_type(type_info)
baseline_params[param.name] = TaskParameter(
type=type_info,
stringified_type=type_str,
default=param.default
if param.default != inspect.Parameter.empty
else None,
required=param.default == inspect.Parameter.empty,
)
except Exception as e:
# If baseline introspection fails, continue without baseline params
print(
f"Warning: Could not introspect baseline parameters for {task_class.__name__}: {e}"
)
# 3. Get additional task metadata
description = self._extract_description(task_class)
display_name = getattr(task_class, "display_name", task_class.__name__)
return TaskInfo(
name=task_class.__name__,
display_name=display_name,
description=description,
task_params=task_params,
baseline_params=baseline_params,
)
except Exception as e:
# Fallback task info if introspection fails
print(f"Warning: Task introspection failed for {task_class.__name__}: {e}")
return TaskInfo(
name=task_class.__name__,
display_name=getattr(task_class, "display_name", task_class.__name__),
description="Task introspection failed - please check task implementation",
task_params={},
baseline_params={},
)
    def _extract_type_info(self, annotation: Any, param_name: str) -> Any:
        """Return the actual annotation (or Any when missing) for downstream strict type checking."""
        if annotation is inspect.Parameter.empty:
            return Any
        return annotation
def _extract_description(self, task_class: Type["Task"]) -> str:
"""Extract description from task class with fallbacks."""
# Try explicit description attribute
if hasattr(task_class, "description"):
return task_class.description
# Try docstring
doc = inspect.getdoc(task_class)
if doc:
# Extract first paragraph of docstring
first_paragraph = doc.split("\n\n")[0].strip()
return first_paragraph
# Fallback
return f"No description available for {task_class.__name__}"
def list_tasks(self) -> List[str]:
"""Returns a list of all available task names."""
return sorted(self._registry.keys())
def get_task_info(self, task_name: str) -> TaskInfo:
"""Gets all introspected information for a given task."""
if task_name not in self._info:
raise ValueError(f"Task '{task_name}' not found.")
return self._info[task_name]
def get_task_class(self, task_name: str) -> Type["Task"]:
"""Gets the class for a given task name."""
if task_name not in self._registry:
available = ", ".join(self.list_tasks())
raise ValueError(
f"Task '{task_name}' not found. Available tasks: {available}"
)
return self._registry[task_name]
def get_task_help(self, task_name: str) -> str:
"""Generate detailed help text for a specific task."""
try:
task_info = self.get_task_info(task_name)
help_text = [
f"Task: {task_info.display_name}",
f"Description: {task_info.description}",
"",
]
if task_info.task_params:
help_text.append("Task Parameters:")
for param_name, param_info in task_info.task_params.items():
required_str = (
"(required)"
if param_info.required
else f"(optional, default: {param_info.default})"
)
help_text.append(
f" --{param_name.replace('_', '-')}: {param_info.type} {required_str}"
)
help_text.append("")
if task_info.baseline_params:
help_text.append("Baseline Parameters (use with --compute-baseline):")
for param_name, param_info in task_info.baseline_params.items():
required_str = (
"(required)"
if param_info.required
else f"(optional, default: {param_info.default})"
)
help_text.append(
f" --baseline-{param_name.replace('_', '-')}: {param_info.type} {required_str}"
)
help_text.append("")
return "\n".join(help_text)
except Exception as e:
return f"Error generating help for task '{task_name}': {e}"
def validate_task_parameters(
self, task_name: str, parameters: Dict[str, Any]
) -> List[str]:
"""Validate parameters for a task and return list of error messages."""
errors = []
try:
task_info = self.get_task_info(task_name)
# Check for unknown parameters
known_params = set(task_info.task_params.keys())
provided_params = set(parameters.keys())
unknown_params = provided_params - known_params
for param in unknown_params:
errors.append(
f"Unknown parameter '{param}'. Available parameters: {list(known_params)}"
)
# Check for missing required parameters
for param_name, param_info in task_info.task_params.items():
if param_info.required and param_name not in parameters:
errors.append(f"Missing required parameter '{param_name}'")
except Exception as e:
errors.append(f"Error validating parameters: {e}")
return errors
# Global singleton instance, ready for import by other modules.
TASK_REGISTRY = TaskRegistry()
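# Usage sketch (illustrative only; the task key "clustering" and the parameter
# "n_clusters" are hypothetical and depend on which Task subclasses have
# actually been imported and registered):
#
#   available = TASK_REGISTRY.list_tasks()
#   info = TASK_REGISTRY.get_task_info("clustering")
#   print(TASK_REGISTRY.get_task_help("clustering"))
#   errors = TASK_REGISTRY.validate_task_parameters("clustering", {"n_clusters": 10})
#   if errors:
#       raise ValueError("; ".join(errors))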
class Task(ABC):
"""Abstract base class for all benchmark tasks.
Defines the interface that all tasks must implement. Tasks are responsible for:
1. Declaring their required input/output data types
2. Running task-specific computations
3. Computing evaluation metrics
Tasks should store any intermediate results as instance variables
to be used in metric computation.
Args:
random_seed (int): Random seed for reproducibility
"""
def __init__(
self,
*,
random_seed: int = RANDOM_SEED,
):
self.random_seed = random_seed
# FIXME should this be changed to requires_multiple_embeddings?
self.requires_multiple_datasets = False
def __init_subclass__(cls, **kwargs):
"""Automatically register task subclasses when they are defined."""
super().__init_subclass__(**kwargs)
TASK_REGISTRY.register_task(cls)
@abstractmethod
def _run_task(
self, cell_representation: CellRepresentation, task_input: TaskInput
) -> TaskOutput:
"""Run the task's core computation.
Should store any intermediate results needed for metric computation
as instance variables.
Args:
cell_representation: gene expression data or embedding for task
task_input: Pydantic model with inputs for the task
Returns:
TaskOutput: Pydantic model with output data for the task
"""
@abstractmethod
def _compute_metrics(
self, task_input: TaskInput, task_output: TaskOutput
) -> List[MetricResult]:
"""Compute evaluation metrics for the task.
Returns:
List of MetricResult objects containing metric values and metadata
"""
def _run_task_for_dataset(
self,
cell_representation: CellRepresentation,
task_input: TaskInput,
) -> List[MetricResult]:
"""Run task for a dataset or list of datasets and compute metrics.
This method runs the task implementation and computes the corresponding metrics.
Args:
cell_representation: gene expression data or embedding for task
task_input: Pydantic model with inputs for the task
Returns:
List of MetricResult objects
"""
task_output = self._run_task(cell_representation, task_input)
metrics = self._compute_metrics(task_input, task_output)
return metrics
def compute_baseline(
self,
expression_data: CellRepresentation,
**kwargs,
) -> CellRepresentation:
"""Set a baseline embedding using PCA on gene expression data.
This method performs standard preprocessing on the raw gene expression data
and uses PCA for dimensionality reduction. It then sets the PCA embedding
as the BASELINE model output in the dataset, which can be used for comparison
with other model embeddings.
Args:
expression_data: expression data to use for anndata
**kwargs: Additional arguments passed to run_standard_scrna_workflow
"""
# Create the AnnData object
adata = ad.AnnData(X=expression_data)
# Run the standard preprocessing workflow
embedding_baseline = run_standard_scrna_workflow(adata, **kwargs)
return embedding_baseline
def run(
self,
cell_representation: Union[CellRepresentation, List[CellRepresentation]],
task_input: TaskInput,
) -> List[MetricResult]:
"""Run the task on input data and compute metrics.
Args:
cell_representation: gene expression data or embedding to use for the task
task_input: Pydantic model with inputs for the task
Returns:
            For a single embedding: a list of MetricResult objects computed for the task
            For multiple embeddings: a list of MetricResult objects computed across the provided datasets
Raises:
ValueError: If input does not match multiple embedding requirement
"""
# Check if task requires embeddings from multiple datasets
if self.requires_multiple_datasets:
error_message = "This task requires a list of cell representations"
if not isinstance(cell_representation, list):
raise ValueError(error_message)
if not all(
[isinstance(emb, CellRepresentation) for emb in cell_representation]
):
raise ValueError(error_message)
            if len(cell_representation) < 2:
                raise ValueError(f"{error_message}, but fewer than two were provided")
else:
if not isinstance(cell_representation, CellRepresentation):
raise ValueError(
"This task requires a single cell representation for input"
)
return self._run_task_for_dataset(
cell_representation, # type: ignore
task_input,
)
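# Minimal end-to-end sketch. The subclass, its input model, and the variables
# `embedding` and `raw_counts` below are illustrative assumptions, not part of
# this module; real tasks implement the same interface with concrete logic.
#
#   class ExampleTaskInput(BaseModel):
#       n_neighbors: int = 15
#
#   class ExampleTask(Task):
#       display_name = "Example Task"
#       description = "Toy task illustrating the Task interface."
#       input_model = ExampleTaskInput  # used by the registry for parameter discovery
#
#       def _run_task(self, cell_representation, task_input):
#           # store any intermediate results on self if metrics need them
#           return TaskOutput()
#
#       def _compute_metrics(self, task_input, task_output):
#           return []  # a real task returns MetricResult objects here
#
#   # Defining the subclass auto-registers it via Task.__init_subclass__.
#   task = ExampleTask(random_seed=RANDOM_SEED)
#   results = task.run(embedding, ExampleTaskInput(n_neighbors=30))
#   baseline_embedding = task.compute_baseline(raw_counts)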