# Source code for czbenchmarks.cli.runner

import logging
from typing import Any, Dict, List, Optional, Union

from anndata import AnnData

from czbenchmarks.tasks.task import TASK_REGISTRY
from czbenchmarks.tasks.types import CellRepresentation

from ..constants import RANDOM_SEED
from .resolve_reference import (
    AnnDataReference,
    is_anndata_reference,
    resolve_value_recursively,
)

logger = logging.getLogger(__name__)


def run_task(
    task_name: str,
    *,
    adata: AnnData,
    cell_representation: Union[str, CellRepresentation],
    run_baseline: bool = False,
    baseline_params: Optional[Dict[str, Any]] = None,
    task_params: Optional[Dict[str, Any]] = None,
    random_seed: int = RANDOM_SEED,
) -> List[Dict[str, Any]]:
    """Run a registered benchmark task and return its results as plain dicts.

    The task class is looked up in ``TASK_REGISTRY`` by ``task_name`` and
    instantiated with ``random_seed``. ``cell_representation``,
    ``task_params`` and ``baseline_params`` are resolved against ``adata``
    before they are used.

    Args:
        task_name: Name passed to ``TASK_REGISTRY.get_task_class`` and
            ``TASK_REGISTRY.validate_task_input``.
        adata: AnnData object that references are resolved against.
        cell_representation: Either the representation itself, or a value
            that ``is_anndata_reference`` recognises; in the latter case it
            is parsed with ``AnnDataReference.parse`` and resolved against
            ``adata``.
        run_baseline: When true, the representation is passed to the task's
            ``compute_baseline`` as ``expression_data`` and the value it
            returns is used for the task run instead.
        baseline_params: Extra keyword arguments for ``compute_baseline``.
            ``None`` (or an empty mapping) means no extra arguments.
        task_params: Keyword arguments used to build the task's
            ``input_model``. ``None`` (or an empty mapping) means none.
        random_seed: Seed forwarded to the task class constructor.

    Returns:
        A list with one ``model_dump()`` dict per result returned by the
        task's ``run`` method.

    Raises:
        ValueError: If the task's ``input_model`` cannot be constructed
            from the resolved task parameters (the original exception is
            chained as the cause).
        Exception: Any error other than ``NotImplementedError`` raised by
            ``compute_baseline`` is logged and re-raised. Errors from the
            registry lookup, reference resolution, ``validate_task_input``
            and ``run`` are not caught here and propagate unchanged.
    """
    logger.info(f"Preparing to run task: '{task_name}' (CLI)")

    # Normalise the optional mappings so they can be resolved/unpacked below.
    baseline_params = baseline_params or {}
    task_params = task_params or {}

    TaskClass = TASK_REGISTRY.get_task_class(task_name)
    InputModel = TaskClass.input_model
    task_instance = TaskClass(random_seed=random_seed)

    # Use the value as given unless it is a reference into ``adata``.
    cell_representation_for_execution = cell_representation
    if is_anndata_reference(cell_representation):
        cell_representation_for_execution = AnnDataReference.parse(
            cell_representation
        ).resolve(adata)

    # Both parameter sets are resolved up front, regardless of run_baseline.
    params_resolved = resolve_value_recursively(task_params, adata)
    baseline_resolved = resolve_value_recursively(baseline_params, adata)

    if run_baseline:
        logger.info(f"Computing baseline for '{task_name}'...")
        try:
            cell_representation_for_execution = task_instance.compute_baseline(
                expression_data=cell_representation_for_execution,
                **baseline_resolved,
            )
            logger.info("Baseline computation complete.")
        except NotImplementedError:
            # Best effort: a task without a baseline still runs, using the
            # representation that was resolved above.
            logger.warning(
                f"Baseline calculation is not implemented for '{task_name}'."
            )
        except Exception as e:
            logger.error(f"Error during baseline computation for '{task_name}': {e}")
            raise

    TASK_REGISTRY.validate_task_input(task_name, params_resolved)
    try:
        task_input = InputModel(**params_resolved)
    except Exception as e:
        logger.error(
            f"Failed to create TaskInput for '{task_name}' with params "
            f"{params_resolved}. Error: {e}"
        )
        # Surface construction failures to callers as a single error type.
        raise ValueError(f"Invalid task parameters for '{task_name}': {e}") from e

    logger.info(f"Executing task logic for '{task_name}' (CLI)...")
    results = task_instance.run(
        cell_representation=cell_representation_for_execution,
        task_input=task_input,
    )
    logger.info(f"Task '{task_name}' execution complete.")

    return [res.model_dump() for res in results]