Source code for czbenchmarks.tasks.label_prediction

import logging
from typing import Any, Dict, List

import pandas as pd
import scipy as sp
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    make_scorer,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from ..constants import RANDOM_SEED
from ..metrics import metrics_registry
from ..metrics.types import MetricResult, MetricType
from ..tasks.types import CellRepresentation
from ..types import ListLike
from .constants import MIN_CLASS_SIZE, N_FOLDS
from .task import Task, TaskInput, TaskOutput
from .utils import filter_minimum_class

logger = logging.getLogger(__name__)


class MetadataLabelPredictionTaskInput(TaskInput):
    """Pydantic model for MetadataLabelPredictionTask inputs."""

    labels: ListLike
    n_folds: int = N_FOLDS
    min_class_size: int = MIN_CLASS_SIZE

class MetadataLabelPredictionOutput(TaskOutput):
    """Output for the label prediction task."""

    results: List[Dict[str, Any]]  # List of dicts with classifier, split, and metrics

class MetadataLabelPredictionTask(Task):
    """Task for predicting labels from embeddings using cross-validation.

    Evaluates multiple classifiers (Logistic Regression, KNN, Random Forest)
    using k-fold cross-validation and reports standard classification metrics.

    Args:
        random_seed (int): Random seed for reproducibility
    """

    display_name = "Label Prediction"
    description = (
        "Predict labels from embeddings using cross-validated classifiers "
        "and standard metrics."
    )
    input_model = MetadataLabelPredictionTaskInput

    def __init__(
        self,
        *,
        random_seed: int = RANDOM_SEED,
    ):
        super().__init__(random_seed=random_seed)

    def _run_task(
        self,
        cell_representation: CellRepresentation,
        task_input: MetadataLabelPredictionTaskInput,
    ) -> MetadataLabelPredictionOutput:
        """Runs the cross-validation prediction task.

        Evaluates multiple classifiers using k-fold cross-validation on the
        cell representation data and stores results for metric computation.

        Args:
            cell_representation: Gene expression data or embedding for the task
            task_input: Pydantic model with inputs for the task

        Returns:
            MetadataLabelPredictionOutput: Pydantic model with results from
            cross-validation
        """
        # FIXME BYOTASK: this is quite baroque and should be broken into sub-tasks
        logger.info("Starting prediction task for labels")

        # Protect from destructive operations
        cell_representation = cell_representation.copy()
        logger.info(
            f"Initial data shape: {cell_representation.shape}, "
            f"labels shape: {task_input.labels.shape}"
        )

        # Filter classes with minimum size requirement
        cell_representation, labels = filter_minimum_class(
            cell_representation,
            task_input.labels,
            min_class_size=task_input.min_class_size,
        )
        logger.info(f"After filtering: {cell_representation.shape} samples remaining")

        # Determine scoring metrics based on number of classes
        n_classes = len(labels.unique())
        target_type = "binary" if n_classes == 2 else "macro"
        logger.info(
            f"Found {n_classes} classes, using {target_type} averaging for metrics"
        )

        scorers = {
            "accuracy": make_scorer(accuracy_score),
            "f1": make_scorer(f1_score, average=target_type),
            "precision": make_scorer(precision_score, average=target_type),
            "recall": make_scorer(recall_score, average=target_type),
            "auroc": make_scorer(
                roc_auc_score,
                average="macro",
                multi_class="ovr",
                response_method="predict_proba",
            ),
        }

        # Set up cross-validation
        skf = StratifiedKFold(
            n_splits=task_input.n_folds, shuffle=True, random_state=self.random_seed
        )
        logger.info(
            f"Using {task_input.n_folds}-fold cross validation "
            f"with random_seed {self.random_seed}"
        )

        # Create classifiers
        classifiers = {
            "lr": Pipeline(
                [("scaler", StandardScaler()), ("lr", LogisticRegression())]
            ),
            "knn": Pipeline(
                [("scaler", StandardScaler()), ("knn", KNeighborsClassifier())]
            ),
            "rf": Pipeline(
                [("rf", RandomForestClassifier(random_state=self.random_seed))]
            ),
        }
        logger.info(f"Created classifiers: {list(classifiers.keys())}")

        # Store results
        results = []

        # Run cross-validation for each classifier
        labels = pd.Categorical(labels.astype(str))
        for name, clf in classifiers.items():
            logger.info(f"Running cross-validation for {name}...")
            cv_results = cross_validate(
                clf,
                cell_representation,
                labels.codes,
                cv=skf,
                scoring=scorers,
                return_train_score=False,
            )
            for fold in range(task_input.n_folds):
                fold_results = {"classifier": name, "split": fold}
                for metric in scorers.keys():
                    fold_results[metric] = cv_results[f"test_{metric}"][fold]
                results.append(fold_results)
                logger.debug(f"{name} fold {fold} results: {fold_results}")

        logger.info("Completed cross-validation for all classifiers")
        return MetadataLabelPredictionOutput(results=results)

    def _compute_metrics(
        self,
        _: MetadataLabelPredictionTaskInput,
        task_output: MetadataLabelPredictionOutput,
    ) -> List[MetricResult]:
        """Computes classification metrics across all folds.

        Aggregates results from cross-validation and computes mean metrics
        per classifier and overall.

        Args:
            _: (unused) Pydantic model with input for the task
            task_output: Pydantic model with results from cross-validation

        Returns:
            List of MetricResult objects containing mean metrics across all
            classifiers and per-classifier metrics
        """
        logger.info("Computing final metrics...")
        results = task_output.results
        results_df = pd.DataFrame(results)
        metrics_list = []

        classifiers = results_df["classifier"].unique()
        all_classifier_names = ",".join(sorted(classifiers))
        params = {"classifier": f"MEAN({all_classifier_names})"}

        # Calculate overall metrics across all classifiers
        metrics_list.extend(
            [
                MetricResult(
                    metric_type=MetricType.MEAN_FOLD_ACCURACY,
                    value=metrics_registry.compute(
                        MetricType.MEAN_FOLD_ACCURACY, results_df=results_df
                    ),
                    params=params,
                ),
                MetricResult(
                    metric_type=MetricType.MEAN_FOLD_F1_SCORE,
                    value=metrics_registry.compute(
                        MetricType.MEAN_FOLD_F1_SCORE, results_df=results_df
                    ),
                    params=params,
                ),
                MetricResult(
                    metric_type=MetricType.MEAN_FOLD_PRECISION,
                    value=metrics_registry.compute(
                        MetricType.MEAN_FOLD_PRECISION, results_df=results_df
                    ),
                    params=params,
                ),
                MetricResult(
                    metric_type=MetricType.MEAN_FOLD_RECALL,
                    value=metrics_registry.compute(
                        MetricType.MEAN_FOLD_RECALL, results_df=results_df
                    ),
                    params=params,
                ),
                MetricResult(
                    metric_type=MetricType.MEAN_FOLD_AUROC,
                    value=metrics_registry.compute(
                        MetricType.MEAN_FOLD_AUROC, results_df=results_df
                    ),
                    params=params,
                ),
            ]
        )

        # Calculate per-classifier metrics
        for clf in results_df["classifier"].unique():
            params = {"classifier": clf}
            metrics_list.extend(
                [
                    MetricResult(
                        metric_type=MetricType.MEAN_FOLD_ACCURACY,
                        value=metrics_registry.compute(
                            MetricType.MEAN_FOLD_ACCURACY,
                            results_df=results_df,
                            classifier=clf,
                        ),
                        params=params,
                    ),
                    MetricResult(
                        metric_type=MetricType.MEAN_FOLD_F1_SCORE,
                        value=metrics_registry.compute(
                            MetricType.MEAN_FOLD_F1_SCORE,
                            results_df=results_df,
                            classifier=clf,
                        ),
                        params=params,
                    ),
                    MetricResult(
                        metric_type=MetricType.MEAN_FOLD_PRECISION,
                        value=metrics_registry.compute(
                            MetricType.MEAN_FOLD_PRECISION,
                            results_df=results_df,
                            classifier=clf,
                        ),
                        params=params,
                    ),
                    MetricResult(
                        metric_type=MetricType.MEAN_FOLD_RECALL,
                        value=metrics_registry.compute(
                            MetricType.MEAN_FOLD_RECALL,
                            results_df=results_df,
                            classifier=clf,
                        ),
                        params=params,
                    ),
                    MetricResult(
                        metric_type=MetricType.MEAN_FOLD_AUROC,
                        value=metrics_registry.compute(
                            MetricType.MEAN_FOLD_AUROC,
                            results_df=results_df,
                            classifier=clf,
                        ),
                        params=params,
                    ),
                ]
            )

        return metrics_list

    def compute_baseline(
        self,
        expression_data: CellRepresentation,
        **kwargs,
    ) -> CellRepresentation:
        """Return a baseline cell representation using raw gene expression.

        Instead of using embeddings from a model, this method uses the raw
        gene expression matrix as features for classification. This provides
        a performance baseline to compare against model-generated embeddings
        for classification tasks.

        Args:
            expression_data: Gene expression data or embedding

        Returns:
            Baseline embedding
        """
        # Convert sparse matrix to dense if needed
        if sp.sparse.issparse(expression_data):
            expression_data = expression_data.toarray()
        return expression_data
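
A minimal usage sketch for the task defined above, assuming an embedding matrix and per-cell labels are already in memory. The data is synthetic and purely illustrative, and the `task.run(...)` entry point plus the shape of the returned metrics are assumptions about the `Task` base class, which is not shown on this page.

import numpy as np
import pandas as pd

from czbenchmarks.tasks.label_prediction import (
    MetadataLabelPredictionTask,
    MetadataLabelPredictionTaskInput,
)

# Synthetic stand-ins for a model embedding (cells x dims) and per-cell labels.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(300, 32))
cell_types = pd.Series(["B", "T", "NK"] * 100, name="cell_type")

task = MetadataLabelPredictionTask(random_seed=42)
task_input = MetadataLabelPredictionTaskInput(labels=cell_types)

# `run` is assumed to be the public entry point supplied by the Task base class
# (it is not defined in this module); it is expected to invoke _run_task and
# _compute_metrics and return the list of MetricResult objects.
metrics = task.run(embeddings, task_input)
for m in metrics:
    print(m.metric_type, m.params.get("classifier"), round(m.value, 3))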