Source code for czbenchmarks.datasets.utils_single_cell

import logging
from typing import Dict, List, Optional, Tuple

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
from pandas.api.types import is_categorical_dtype
from tqdm.auto import tqdm

logger = logging.getLogger(__name__)


def _slice_in_memory(adata: ad.AnnData, rows: np.ndarray) -> ad.AnnData:
    """Return an independent, in-memory AnnData containing ``adata[rows]``.

    ``AnnData.copy()`` on a view of a backed object raises ``ValueError``
    unless a filename is supplied, so backed inputs are materialised with
    ``to_memory()`` instead of ``copy()``.
    """
    subset = adata[rows]
    return subset.to_memory() if adata.isbacked else subset.copy()


def run_multicondition_dge_analysis(
    adata: ad.AnnData,
    condition_key: str,
    de_gene_col: str,
    control_cells_ids: Dict[str, List[str]],
    filter_min_cells: int = 10,
    filter_min_genes: int = 1000,
    min_pert_cells: int = 50,
    remove_avg_zeros: bool = False,
    store_dge_metadata: bool = False,
    return_merged_adata: bool = False,
) -> Tuple[pd.DataFrame, Optional[ad.AnnData]]:
    """
    Run differential gene expression analysis for a list of conditions
    between perturbed and matched control cells.

    For every condition key of ``control_cells_ids``, the cells whose
    ``adata.obs[condition_key]`` equals that condition are concatenated with
    the listed control cells. The merged object is filtered with
    ``sc.pp.filter_cells`` / ``sc.pp.filter_genes``, normalised to 1e4 counts
    per cell, log1p-transformed, and tested with
    ``sc.tl.rank_genes_groups(method="wilcoxon", reference="control")``.
    Conditions with too few cells (before or after filtering) or with no
    genes left after filtering are skipped with a logged warning.
    ``adata`` itself is not modified.

    Parameters
    ----------
    adata (AnnData): Annotated data matrix containing gene expression and metadata.
        May be backed; selected cells are loaded into memory per condition.
    condition_key (str): Column name for condition labels in `adata.obs`.
    de_gene_col (str): Name of the gene-identifier column in the returned
        results (its values are the tested genes' `adata.var_names`).
    control_cells_ids (Dict[str, List[str]]): Mapping from condition -> list
        of matched control cell ids (`adata.obs_names`). Ids not found in
        `adata` are ignored.
    filter_min_cells (int, optional): Minimum number of cells expressing a gene
        to include that gene. Defaults to 10.
    filter_min_genes (int, optional): Minimum number of genes detected per cell.
        Defaults to 1000.
    min_pert_cells (int, optional): Minimum number of perturbed cells required
        (checked before filtering). Defaults to 50.
    remove_avg_zeros (bool, optional): Whether to drop genes whose mean raw
        expression is zero in either the condition or the control cells.
        Defaults to False.
    store_dge_metadata (bool, optional): Whether to add a `dge_params` column
        holding the test/filter parameters (the same dict on every row).
        Defaults to False.
    return_merged_adata (bool, optional): Whether to return the merged AnnData
        object. Defaults to False.

    Returns
    -------
    Tuple[pd.DataFrame, Optional[anndata.AnnData]]
        (results_df, adata_merged):
        - results_df: Differential expression results for all analysed
          conditions, with columns `de_gene_col`, `score`, `logfoldchange`,
          `pval`, `pval_adj` (those produced by scanpy) and `condition_key`,
          plus `dge_params` if `store_dge_metadata`. Empty if no condition
          could be analysed.
        - adata_merged: Concatenation of the per-condition merged, filtered and
          normalised AnnData objects if `return_merged_adata` is True,
          otherwise None.
    """
    if return_merged_adata:
        logger.warning(
            "return_merged_adata is True, which can consume a large amount of memory."
        )

    deg_test_name = "wilcoxon"
    obs = adata.obs
    obs_index = obs.index

    # Group on a local categorical copy of the condition column (faster
    # grouping) instead of rewriting the caller's adata.obs in place.
    conditions = obs[condition_key]
    if not isinstance(conditions.dtype, pd.CategoricalDtype):
        conditions = conditions.astype("category")
    # condition -> integer row positions
    condition_to_indices = conditions.groupby(conditions, observed=True).indices

    empty_rows = np.array([], dtype=int)
    results_frames: List[pd.DataFrame] = []
    merged_objects: List[ad.AnnData] = []
    dge_params: dict = {}

    for selected_condition in tqdm(
        list(control_cells_ids), desc="Analyzing conditions", unit="cond"
    ):
        rows_cond = condition_to_indices.get(selected_condition, empty_rows)
        # Control ids -> row positions (order preserved); unknown ids map to -1.
        rows_ctrl = np.asarray(
            obs_index.get_indexer_for(control_cells_ids[selected_condition]),
            dtype=int,
        )
        rows_ctrl = rows_ctrl[rows_ctrl >= 0]

        if len(rows_cond) < min_pert_cells or len(rows_ctrl) == 0:
            logger.warning(f"Insufficient cells for analysis of {selected_condition}")
            continue

        # Independent in-memory copies, so editing .obs below is safe.
        adata_condition = _slice_in_memory(adata, rows_cond)
        adata_control = _slice_in_memory(adata, rows_ctrl)
        if adata_condition.n_obs != adata_control.n_obs:
            logger.warning(
                f"Condition and control data for {selected_condition} have different lengths."
            )

        # Label each slice with its comparison group before concatenation.
        adata_condition.obs["comparison_group"] = selected_condition
        adata_control.obs["comparison_group"] = "control"
        # ad.concat already returns a new object; no extra copy needed.
        adata_merged = ad.concat([adata_condition, adata_control], index_unique=None)

        # Filter before checking group sizes: filtering may drop cells/genes.
        sc.pp.filter_cells(adata_merged, min_genes=filter_min_genes)
        sc.pp.filter_genes(adata_merged, min_cells=filter_min_cells)

        # rank_genes_groups raises ValueError for any group with fewer than two
        # cells, so both groups need at least two surviving cells.
        group_sizes = adata_merged.obs["comparison_group"].value_counts()
        if len(group_sizes) < 2 or group_sizes.min() < 2:
            logger.warning(
                f"Insufficient filtered cells for analysis of {selected_condition}"
            )
            continue
        if adata_merged.n_vars == 0:
            logger.warning(
                f"No genes passed filtering for analysis of {selected_condition}"
            )
            continue

        sc.pp.normalize_total(adata_merged, target_sum=1e4)
        sc.pp.log1p(adata_merged)

        # Run statistical test
        sc.tl.rank_genes_groups(
            adata_merged,
            groupby="comparison_group",
            reference="control",
            method=deg_test_name,
            key_added="dge_results",
        )
        results = sc.get.rank_genes_groups_df(
            adata_merged, group=selected_condition, key="dge_results"
        )
        results[condition_key] = selected_condition
        # Capture params from a run that actually produced results (the last
        # loop iteration may have been skipped after merging).
        dge_params = dict(adata_merged.uns["dge_results"].get("params", {}))

        if remove_avg_zeros:
            gene_ids = results["names"].values
            # np.asarray(...).ravel() flattens both dense means and the
            # (1, n) np.matrix returned by scipy sparse .mean(axis=0).
            cond_mean = np.asarray(
                adata_condition[:, gene_ids].X.mean(axis=0)
            ).ravel()
            ctrl_mean = np.asarray(adata_control[:, gene_ids].X.mean(axis=0)).ravel()
            keep = (cond_mean > 0) & (ctrl_mean > 0)
            logger.info(
                f"remove_avg_zeros is True. Removing {len(results) - int(keep.sum())} genes with zero expression"
            )
            results = results.iloc[np.flatnonzero(keep)]

        results_frames.append(results)
        if return_merged_adata:
            merged_objects.append(adata_merged)

    if not results_frames:
        return pd.DataFrame(), None

    results = pd.concat(results_frames, ignore_index=True)
    del results_frames

    merged: Optional[ad.AnnData] = None
    if return_merged_adata:
        merged = ad.concat(merged_objects, index_unique=None)
    del merged_objects

    # Standardize column names
    col_mapper = {
        "names": de_gene_col,
        "scores": "score",
        "logfoldchanges": "logfoldchange",
        "pvals": "pval",
        "pvals_adj": "pval_adj",
    }
    results = results.rename(columns=col_mapper)
    cols = [c for c in col_mapper.values() if c in results.columns] + [condition_key]
    results = results[cols]

    if store_dge_metadata:
        dge_params.update(
            {
                "remove_avg_zeros": remove_avg_zeros,
                "filter_min_cells": filter_min_cells,
                "filter_min_genes": filter_min_genes,
                "min_pert_cells": min_pert_cells,
            }
        )
        # Assigning a bare dict would make pandas align its keys against the
        # row index (all NaN); give every row the dict instead. Not tidy.
        results = results.assign(dge_params=[dge_params] * len(results))

    return results, merged