Source code for czbenchmarks.tasks.single_cell.perturbation

from typing import Literal, Set, List
import pandas as pd
import scanpy as sc
import anndata as ad
import numpy as np
import logging
from ..base import BaseTask
from ...datasets import PerturbationSingleCellDataset, DataType
from ...metrics import metrics_registry
from ...metrics.types import MetricResult, MetricType
from ...models.types import ModelType

logger = logging.getLogger(__name__)


class PerturbationTask(BaseTask):
    """Task for evaluating perturbation prediction quality.

    This task computes metrics that assess how well a model predicts gene
    expression changes in response to perturbations, comparing predicted
    against ground-truth perturbation effects using MSE, Pearson
    correlation, and Jaccard metrics.
    """

    @property
    def display_name(self) -> str:
        """A pretty name to use when displaying task results"""
        return "perturbation"

    @property
    def required_inputs(self) -> Set[DataType]:
        """Required input data types.

        Returns:
            Set of required input DataTypes (ground truth perturbation effects)
        """
        return {DataType.PERTURBATION_TRUTH}

    @property
    def required_outputs(self) -> Set[DataType]:
        """Required output data types.

        Returns:
            Set of output DataTypes required from models for this task to run
            (predicted perturbation effects)
        """
        return {DataType.PERTURBATION_PRED}

    def _run_task(self, data: PerturbationSingleCellDataset, model_type: ModelType):
        """Runs the perturbation evaluation task.

        Gets predicted perturbation effects, ground truth effects, and control
        expression from the dataset for metric computation.

        Args:
            data: Dataset containing perturbation predictions and ground truth
            model_type: Model whose predicted perturbation effects are evaluated
        """
        self.gene_pert, self.perturbation_pred = data.get_output(
            model_type, DataType.PERTURBATION_PRED
        )
        self.perturbation_truth = data.perturbation_truth
        self.perturbation_ctrl = data.adata.X.toarray()
        self.avg_perturbation_ctrl = pd.Series(
            data=self.perturbation_ctrl.mean(0),
            index=data.adata.var_names,
            name="ctrl",
        )

    def _compute_metrics(self) -> List[MetricResult]:
        """Computes perturbation prediction quality metrics.

        For each perturbation, computes:
        - MSE between predicted and true expression
        - Pearson correlation between predicted and true expression changes
          from control
        - Jaccard similarity between predicted and true top differentially
          expressed genes

        Returns:
            List of MetricResult objects containing metric values and metadata
        """
        avg_perturbation_control = self.avg_perturbation_ctrl
        mean_squared_error_metric = MetricType.MEAN_SQUARED_ERROR
        pearson_correlation_metric = MetricType.PEARSON_CORRELATION
        jaccard_metric = MetricType.JACCARD

        if self.gene_pert not in self.perturbation_truth:
            raise ValueError(
                f"Perturbation {self.gene_pert} is not available in the ground "
                "truth test perturbations."
            )

        # Run differential expression analysis between control and predicted/truth.
        # Create AnnData objects for control, prediction, and truth.
        adata_ctrl = ad.AnnData(X=self.perturbation_ctrl)
        adata_pred = ad.AnnData(X=self.perturbation_pred.values)
        adata_truth = ad.AnnData(X=self.perturbation_truth[self.gene_pert].values)

        # Ensure they have the same var_names
        genes = self.perturbation_pred.columns
        adata_ctrl.var_names = genes
        adata_pred.var_names = genes
        adata_truth.var_names = genes

        # Create combined AnnData for control vs prediction
        adata_ctrl_pred = ad.AnnData(
            X=np.vstack([adata_ctrl.X, adata_pred.X]),
            obs={
                "condition": ["ctrl"] * adata_ctrl.n_obs
                + ["pred"] * adata_pred.n_obs
            },
        )
        adata_ctrl_pred.var_names = genes

        # Create combined AnnData for control vs truth
        adata_ctrl_truth = ad.AnnData(
            X=np.vstack([adata_ctrl.X, adata_truth.X]),
            obs={
                "condition": ["ctrl"] * adata_ctrl.n_obs
                + ["truth"] * adata_truth.n_obs
            },
        )
        adata_ctrl_truth.var_names = genes

        # Run rank_genes_groups for control vs prediction
        sc.tl.rank_genes_groups(
            adata_ctrl_pred,
            groupby="condition",
            groups=["pred"],
            reference="ctrl",
            method="wilcoxon",
        )

        # Run rank_genes_groups for control vs truth
        sc.tl.rank_genes_groups(
            adata_ctrl_truth,
            groupby="condition",
            groups=["truth"],
            reference="ctrl",
            method="wilcoxon",
        )

        # Store the results for later use if needed
        self.de_results_pred = sc.get.rank_genes_groups_df(
            adata_ctrl_pred, group="pred"
        )
        self.de_results_truth = sc.get.rank_genes_groups_df(
            adata_ctrl_truth, group="truth"
        )

        avg_perturbation_pred = self.perturbation_pred.mean(axis=0)
        avg_perturbation_truth = self.perturbation_truth[self.gene_pert].mean(
            axis=0
        )
        intersecting_genes = list(
            set(avg_perturbation_pred.index)
            & set(avg_perturbation_truth.index)
            & set(avg_perturbation_control.index)
        )

        # 1. Calculate metrics for all genes
        mse_all = metrics_registry.compute(
            mean_squared_error_metric,
            y_true=avg_perturbation_truth[intersecting_genes],
            y_pred=avg_perturbation_pred[intersecting_genes],
        )
        delta_pearson_corr_all = metrics_registry.compute(
            pearson_correlation_metric,
            x=avg_perturbation_truth[intersecting_genes]
            - avg_perturbation_control[intersecting_genes],
            y=avg_perturbation_pred[intersecting_genes]
            - avg_perturbation_control[intersecting_genes],
        ).statistic

        # 2. Calculate metrics for top 20 DE genes
        top20_de_genes = (
            self.de_results_truth.sort_values("scores", ascending=False)
            .head(20)["names"]
            .tolist()
        )
        top20_de_genes = [
            gene for gene in top20_de_genes if gene in intersecting_genes
        ]
        mse_top20 = metrics_registry.compute(
            mean_squared_error_metric,
            y_true=avg_perturbation_truth[top20_de_genes],
            y_pred=avg_perturbation_pred[top20_de_genes],
        )
        delta_pearson_corr_top20 = metrics_registry.compute(
            pearson_correlation_metric,
            x=avg_perturbation_truth[top20_de_genes]
            - avg_perturbation_control[top20_de_genes],
            y=avg_perturbation_pred[top20_de_genes]
            - avg_perturbation_control[top20_de_genes],
        ).statistic

        # 3. Calculate metrics for top 100 DE genes
        top100_de_genes = (
            self.de_results_truth.sort_values("scores", ascending=False)
            .head(100)["names"]
            .tolist()
        )
        top100_de_genes = [
            gene for gene in top100_de_genes if gene in intersecting_genes
        ]
        mse_top100 = metrics_registry.compute(
            mean_squared_error_metric,
            y_true=avg_perturbation_truth[top100_de_genes],
            y_pred=avg_perturbation_pred[top100_de_genes],
        )
        delta_pearson_corr_top100 = metrics_registry.compute(
            pearson_correlation_metric,
            x=avg_perturbation_truth[top100_de_genes]
            - avg_perturbation_control[top100_de_genes],
            y=avg_perturbation_pred[top100_de_genes]
            - avg_perturbation_control[top100_de_genes],
        ).statistic

        # Calculate Jaccard similarity for top DE genes
        top20_pred_de_genes = set(
            self.de_results_pred.sort_values("scores", ascending=False)
            .head(20)["names"]
            .tolist()
        )
        top20_truth_de_genes = set(
            self.de_results_truth.sort_values("scores", ascending=False)
            .head(20)["names"]
            .tolist()
        )
        jaccard_top20 = metrics_registry.compute(
            jaccard_metric,
            y_true=top20_truth_de_genes,
            y_pred=top20_pred_de_genes,
        )
        top100_pred_de_genes = set(
            self.de_results_pred.sort_values("scores", ascending=False)
            .head(100)["names"]
            .tolist()
        )
        top100_truth_de_genes = set(
            self.de_results_truth.sort_values("scores", ascending=False)
            .head(100)["names"]
            .tolist()
        )
        jaccard_top100 = metrics_registry.compute(
            jaccard_metric,
            y_true=top100_truth_de_genes,
            y_pred=top100_pred_de_genes,
        )

        return [
            MetricResult(
                metric_type=mean_squared_error_metric,
                value=mse_all,
                params={"subset": "all"},
            ),
            MetricResult(
                metric_type=pearson_correlation_metric,
                value=delta_pearson_corr_all,
                params={"subset": "all"},
            ),
            MetricResult(
                metric_type=mean_squared_error_metric,
                value=mse_top20,
                params={"subset": "top20"},
            ),
            MetricResult(
                metric_type=pearson_correlation_metric,
                value=delta_pearson_corr_top20,
                params={"subset": "top20"},
            ),
            MetricResult(
                metric_type=mean_squared_error_metric,
                value=mse_top100,
                params={"subset": "top100"},
            ),
            MetricResult(
                metric_type=pearson_correlation_metric,
                value=delta_pearson_corr_top100,
                params={"subset": "top100"},
            ),
            MetricResult(
                metric_type=jaccard_metric,
                value=jaccard_top20,
                params={"subset": "top20"},
            ),
            MetricResult(
                metric_type=jaccard_metric,
                value=jaccard_top100,
                params={"subset": "top100"},
            ),
        ]
    def set_baseline(
        self,
        data: PerturbationSingleCellDataset,
        gene_pert: str,
        baseline_type: Literal["median", "mean"] = "median",
        **kwargs,
    ):
        """Set a baseline prediction for perturbation evaluation.

        Creates a baseline prediction by applying a simple statistical
        aggregation (median or mean) to the control expression data and
        stores it in the dataset as the BASELINE model output, so it can be
        evaluated against the ground truth like any other model prediction.

        Args:
            data: PerturbationSingleCellDataset containing control and perturbed data
            gene_pert: The perturbation gene to evaluate
            baseline_type: The statistical method used for the baseline
                prediction ("median" or "mean")
            **kwargs: Additional arguments passed to the evaluation
        """
        # Create the baseline prediction by replicating the aggregated control
        # expression values across all cells in the dataset.
        baseline_func = np.median if baseline_type == "median" else np.mean
        perturb_baseline_pred = pd.DataFrame(
            np.tile(
                baseline_func(data.adata.X.toarray(), axis=0),
                (data.adata.shape[0], 1),
            ),
            columns=data.adata.var_names,  # Use gene names from the dataset
            index=data.adata.obs_names,  # Use cell names from the dataset
        )

        # Store the baseline prediction in the dataset for evaluation
        data.set_output(
            ModelType.BASELINE,
            DataType.PERTURBATION_PRED,
            (gene_pert, perturb_baseline_pred),
        )
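
# Illustration (not part of the library module above): a minimal,
# self-contained sketch of the DE-based metric steps in _compute_metrics on
# toy data -- Wilcoxon rank_genes_groups between control and perturbed cells,
# top-k gene selection by score, and the delta-Pearson correlation between
# predicted and true mean expression shifts. The toy matrices, the helper
# name _demo_de_metrics, and the stand-in prediction avg_pred are invented
# for this example; the scanpy/anndata calls mirror the ones in the class.
def _demo_de_metrics():
    from scipy.stats import pearsonr

    rng = np.random.default_rng(0)
    genes = [f"gene_{i}" for i in range(50)]

    # Control cells, plus "perturbed" cells where the first 5 genes are shifted up.
    ctrl = rng.normal(loc=1.0, scale=0.1, size=(30, 50))
    pert = ctrl.copy()
    pert[:, :5] += 2.0

    combined = ad.AnnData(
        X=np.vstack([ctrl, pert]),
        obs={"condition": ["ctrl"] * 30 + ["pert"] * 30},
    )
    combined.var_names = genes

    # Wilcoxon rank-sum DE between perturbed and control cells.
    sc.tl.rank_genes_groups(
        combined,
        groupby="condition",
        groups=["pert"],
        reference="ctrl",
        method="wilcoxon",
    )
    de = sc.get.rank_genes_groups_df(combined, group="pert")
    top5 = set(de.sort_values("scores", ascending=False).head(5)["names"])
    print(top5)  # recovers the five genes that were actually shifted

    # Delta-Pearson on mean profiles: correlate (truth - ctrl) with (pred - ctrl).
    avg_ctrl = ctrl.mean(axis=0)
    avg_truth = pert.mean(axis=0)
    avg_pred = avg_truth + rng.normal(scale=0.05, size=50)  # stand-in prediction
    print(pearsonr(avg_truth - avg_ctrl, avg_pred - avg_ctrl).statistic)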
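
# Illustration (not part of the library module above): the baseline built by
# set_baseline is just the per-gene median (or mean) of the control matrix,
# tiled across all cells. A tiny worked example; the helper name and the
# gene/cell labels are made up for this sketch.
def _demo_baseline():
    X = np.array([[1.0, 4.0], [3.0, 8.0], [5.0, 6.0]])  # 3 cells x 2 genes
    baseline = pd.DataFrame(
        np.tile(np.median(X, axis=0), (X.shape[0], 1)),
        columns=["gene_a", "gene_b"],
        index=["cell_1", "cell_2", "cell_3"],
    )
    # Every row is the per-gene median of the control matrix: [3.0, 6.0].
    print(baseline)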