Add a Custom Task
This guide explains how to create and integrate your own evaluation task into cz-benchmarks.
Overview of the Task System
Required Components
| Component | Description | Example |
| --- | --- | --- |
| `display_name` | Human-readable name for the task | `"My Example Task"` |
| `input_model` | Pydantic class defining task input schema | `MyTaskInput` |
| `_run_task()` | Core task implementation method | Returns a `TaskOutput` |
| `_compute_metrics()` | Metric computation method | Returns `List[MetricResult]` |
Optional Components
| Component | Description | Default |
| --- | --- | --- |
| `description` | One-sentence task description | Extracted from docstring |
| `compute_baseline()` | Baseline implementation for comparison | Raises `NotImplementedError` |
| `requires_multiple_datasets` | Flag for multi-dataset tasks | `False` |
Key Features
- **Automatic Registration**: Tasks are automatically registered by the `Task` base class (in `__init_subclass__`)
- **Arbitrary Types**: Input/output models support complex objects (DataFrames, arrays) via `model_config = {"arbitrary_types_allowed": True}` (see the sketch below)
- **Type Safety**: Full Pydantic validation for all inputs
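To illustrate how `arbitrary_types_allowed` behaves, here is a minimal, generic Pydantic sketch (not the library's base-class code; `ExampleInput` and its fields are hypothetical). The DataFrame field is accepted because arbitrary types are allowed, while malformed values are still rejected:

```python
import pandas as pd
from pydantic import BaseModel, ValidationError


class ExampleInput(BaseModel):
    # Without this config, Pydantic rejects pd.DataFrame as a field type
    model_config = {"arbitrary_types_allowed": True}

    labels: list[str]
    metadata: pd.DataFrame


ok = ExampleInput(labels=["a", "b"], metadata=pd.DataFrame({"x": [1, 2]}))

try:
    ExampleInput(labels=["a", "b"], metadata="not a dataframe")
except ValidationError as err:
    print(err)  # metadata is not a DataFrame instance
```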
Step-by-Step Implementation
1. Define Input and Output Models
Create Pydantic models that inherit from `TaskInput` and `TaskOutput`:
```python
from typing import List

import pandas as pd

from czbenchmarks.tasks import TaskInput, TaskOutput
from czbenchmarks.types import ListLike


class MyTaskInput(TaskInput):
    """Input model for MyTask."""

    ground_truth_labels: ListLike
    metadata: pd.DataFrame  # Example of arbitrary type support


class MyTaskOutput(TaskOutput):
    """Output model for MyTask."""

    predictions: List[float]
    confidence_scores: List[float]
```
2. Implement the Task Class
Create a new file in `src/czbenchmarks/tasks/` (e.g., `my_task.py`):
```python
import logging
from typing import List

import numpy as np

from ..constants import RANDOM_SEED
from ..metrics import metrics_registry
from ..metrics.types import MetricResult, MetricType
from ..types import ListLike
from .task import Task, TaskInput, TaskOutput
from .types import CellRepresentation

logger = logging.getLogger(__name__)


class MyTaskInput(TaskInput):
    """Input model for MyTask."""

    ground_truth_labels: ListLike


class MyTaskOutput(TaskOutput):
    """Output model for MyTask."""

    predictions: List[float]


class MyTask(Task):
    """Example task that demonstrates the basic task implementation pattern.

    This task performs a simple prediction based on cell embeddings
    and evaluates against ground truth labels.
    """

    # REQUIRED: Class attributes for task metadata
    display_name = "My Example Task"
    description = "Predicts numeric labels from cell embeddings using a simple algorithm."
    input_model = MyTaskInput

    def __init__(self, my_param: int = 10, *, random_seed: int = RANDOM_SEED):
        """Initialize the task with custom parameters.

        Args:
            my_param: Custom parameter for the task
            random_seed: Random seed for reproducibility
        """
        super().__init__(random_seed=random_seed)
        self.my_param = my_param
        logger.info(f"Initialized {self.display_name} with my_param={my_param}")

    def _run_task(
        self,
        cell_representation: CellRepresentation,
        task_input: MyTaskInput,
    ) -> MyTaskOutput:
        """Core task implementation.

        Args:
            cell_representation: Cell embeddings or gene expression data
            task_input: Validated input parameters

        Returns:
            Task output containing predictions
        """
        logger.info(f"Running task on {len(task_input.ground_truth_labels)} samples")

        # Example implementation - replace with your logic
        np.random.seed(self.random_seed)
        predictions = np.random.random(len(task_input.ground_truth_labels)).tolist()

        return MyTaskOutput(predictions=predictions)

    def _compute_metrics(
        self,
        task_input: MyTaskInput,
        task_output: MyTaskOutput,
    ) -> List[MetricResult]:
        """Compute evaluation metrics.

        Args:
            task_input: Original task input
            task_output: Results from _run_task

        Returns:
            List of metric results
        """
        # Use metrics registry to compute standard metrics
        metrics = []

        # Example: Compute correlation if applicable
        if len(task_input.ground_truth_labels) == len(task_output.predictions):
            correlation = metrics_registry.compute(
                MetricType.PEARSON_CORRELATION,
                y_true=task_input.ground_truth_labels,
                y_pred=task_output.predictions,
            )
            metrics.append(
                MetricResult(
                    metric_type=MetricType.PEARSON_CORRELATION,
                    value=correlation,
                    params={"my_param": self.my_param},
                )
            )

        return metrics

    def compute_baseline(
        self,
        expression_data: CellRepresentation,
        **kwargs,
    ) -> CellRepresentation:
        """Optional: Compute baseline embedding using standard preprocessing.

        Args:
            expression_data: Raw gene expression data
            **kwargs: Additional parameters for baseline computation

        Returns:
            Baseline embedding for comparison
        """
        # Use the parent class implementation for the PCA baseline
        return super().compute_baseline(expression_data, **kwargs)
```
3. Register the Task
Add your task to `src/czbenchmarks/tasks/__init__.py`:
```python
# Add these imports
from .my_task import MyTask, MyTaskInput, MyTaskOutput

# Add to __all__ list
__all__ = [
    # ... existing exports ...
    "MyTask",
    "MyTaskInput",
    "MyTaskOutput",
]
```
Note: Registration happens automatically when the class is defined, thanks to `__init_subclass__`. Adding the task to `__init__.py` simply makes it easy to import.
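The registration mechanism follows the standard `__init_subclass__` pattern. The sketch below is a simplified illustration of that pattern (the names `ExampleRegistry`, `REGISTRY`, and `BaseTask` are illustrative, not the cz-benchmarks implementation):

```python
class ExampleRegistry:
    """Toy registry: maps class names to task classes."""

    def __init__(self):
        self._tasks = {}

    def register(self, cls):
        self._tasks[cls.__name__] = cls

    def list_tasks(self):
        return sorted(self._tasks)


REGISTRY = ExampleRegistry()


class BaseTask:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs whenever a subclass is *defined*, so simply importing
        # a module containing a task class is enough to register it.
        REGISTRY.register(cls)


class ExampleTask(BaseTask):
    pass


print(REGISTRY.list_tasks())  # ['ExampleTask']
```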
4. Test Your Task
```python
# Test script example
import numpy as np

from czbenchmarks.tasks import MyTask, MyTaskInput

# Create test data
cell_rep = np.random.random((100, 50))  # 100 cells, 50 features
task_input = MyTaskInput(ground_truth_labels=np.random.randint(0, 3, 100))

# Run task
task = MyTask(my_param=5, random_seed=42)
results = task.run(cell_rep, task_input)
print(f"Computed {len(results)} metrics")
```
5. Update Documentation
Add your task to `docs/source/developer_guides/tasks.md`:
```markdown
### Available Tasks

- **My Example Task** – Predicts numeric labels from cell embeddings using a simple algorithm.
  See: `czbenchmarks.tasks.my_task.MyTask`
```
Advanced Features
Multi-Dataset Tasks
For tasks requiring multiple datasets (e.g., integration tasks):
```python
def __init__(self, *, random_seed: int = RANDOM_SEED):
    super().__init__(random_seed=random_seed)
    self.requires_multiple_datasets = True  # Enable multi-dataset mode
```
Custom Baseline Parameters
Document baseline parameters in the method signature:
```python
def compute_baseline(
    self,
    expression_data: CellRepresentation,
    n_components: int = 50,
    highly_variable_genes: bool = True,
    **kwargs,
) -> CellRepresentation:
    """Compute PCA baseline with custom parameters."""
    return super().compute_baseline(
        expression_data,
        n_components=n_components,
        highly_variable_genes=highly_variable_genes,
        **kwargs,
    )
```
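As a usage sketch (hypothetical data; it assumes the parent-class PCA baseline accepts raw expression data and returns an embedding that `run` accepts like any other cell representation), you can score the baseline and a model embedding with the same task and compare the resulting metrics:

```python
import numpy as np

from czbenchmarks.tasks import MyTask, MyTaskInput

raw_expression = np.random.random((100, 500))  # hypothetical raw gene expression
model_embedding = np.random.random((100, 50))  # hypothetical model output
task_input = MyTaskInput(ground_truth_labels=np.random.randint(0, 3, 100))

task = MyTask(random_seed=42)
baseline_rep = task.compute_baseline(raw_expression, n_components=50)

baseline_results = task.run(baseline_rep, task_input)
model_results = task.run(model_embedding, task_input)
```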
Task Discovery and CLI Integration
Tasks are automatically discovered via the `TASK_REGISTRY`:
```python
from czbenchmarks.tasks import TASK_REGISTRY

# List all available tasks
print(TASK_REGISTRY.list_tasks())

# Get task information
info = TASK_REGISTRY.get_task_info("my_example_task")
print(f"Description: {info.description}")
print(f"Parameters: {list(info.task_params.keys())}")
```
Tips
✅ Single Responsibility: Each task should solve one well-defined problem
✅ Reproducibility: Pass
self.random_seed
to any library function calls that have stochastic behavior✅ Type Safety: Use explicit type hints throughout
✅ Logging: Log key steps for debugging (
logger.info
,logger.debug
)✅ Error Handling: Provide informative error messages
✅ Documentation: Clear docstrings for all public methods
✅ Testing: Unit tests for input validation, core logic, and metrics
✅ Performance: Consider memory usage for large datasets
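For the testing tip, a minimal pytest sketch might look like the following (the assertions and data shapes are illustrative; adapt them to your task's actual metrics):

```python
import numpy as np

from czbenchmarks.tasks import MyTask, MyTaskInput


def test_my_task_produces_metrics():
    cell_rep = np.random.random((20, 10))
    task_input = MyTaskInput(ground_truth_labels=np.random.randint(0, 2, 20))

    results = MyTask(my_param=3, random_seed=0).run(cell_rep, task_input)

    assert len(results) > 0
    assert all(np.isfinite(result.value) for result in results)


def test_my_task_is_reproducible():
    cell_rep = np.random.random((20, 10))
    task_input = MyTaskInput(ground_truth_labels=np.random.randint(0, 2, 20))

    first = MyTask(random_seed=42).run(cell_rep, task_input)
    second = MyTask(random_seed=42).run(cell_rep, task_input)

    assert [m.value for m in first] == [m.value for m in second]
```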