"""Utilities for czbenchmarks tasks (``czbenchmarks.tasks.utils``)."""

import logging
from typing import List

import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from .constants import RANDOM_SEED, FLAVOR, KEY_ADDED, OBSM_KEY

logger = logging.getLogger(__name__)

# Names of tasks that work across multiple datasets at once.
MULTI_DATASET_TASK_NAMES = frozenset({"cross_species"})

# Every recognised task name: the single-dataset tasks listed here, plus the
# multi-dataset tasks above.
TASK_NAMES = (
    frozenset(
        {
            "clustering",
            "embedding",
            "label_prediction",
            "integration",
            "perturbation",
        }
    )
    | MULTI_DATASET_TASK_NAMES
)


# TODO: The clustering parameters below are already keyword arguments; they
# still need to be surfaced in the task config.
def cluster_embedding(
    adata: AnnData,
    obsm_key: str = OBSM_KEY,
    random_seed: int = RANDOM_SEED,
    n_iterations: int = 2,
    flavor: str = FLAVOR,
    key_added: str = KEY_ADDED,
) -> List[int]:
    """Cluster cells in embedding space using the Leiden algorithm.

    Builds a nearest-neighbor graph from ``adata.obsm[obsm_key]`` with
    ``sc.pp.neighbors`` and then runs ``sc.tl.leiden`` on that graph.
    ``adata`` is modified in place: scanpy stores the neighbor graph and the
    cluster column on it.

    Args:
        adata: AnnData object containing the embedding.
        obsm_key: Key in ``adata.obsm`` holding the embedding coordinates.
        random_seed: Random seed passed to both neighbors and Leiden, for
            reproducibility.
        n_iterations: Number of iterations for the Leiden algorithm.
        flavor: Leiden implementation flavor passed to ``sc.tl.leiden``.
        key_added: Column in ``adata.obs`` where Leiden writes the cluster
            assignments; the return value is read back from this column.

    Returns:
        List of per-cell cluster assignments, in ``adata.obs`` order, as
        stored in ``adata.obs[key_added]``.
        NOTE(review): scanpy stores Leiden labels as categorical strings, so
        the elements may be ``str`` despite the ``List[int]`` annotation;
        confirm against callers.
    """
    sc.pp.neighbors(adata, use_rep=obsm_key, random_state=random_seed)
    sc.tl.leiden(
        adata,
        key_added=key_added,
        flavor=flavor,
        n_iterations=n_iterations,
        random_state=random_seed,
    )
    # Read from the column Leiden actually wrote to. A hard-coded "leiden"
    # raises KeyError, or returns stale labels from an earlier run, whenever
    # key_added is something else.
    return list(adata.obs[key_added])
def filter_minimum_class(
    features: np.ndarray,
    labels: np.ndarray | pd.Series,
    min_class_size: int = 10,
) -> tuple[np.ndarray, pd.Categorical]:
    """Filter data to remove classes with too few samples.

    Removes every sample whose class has fewer than ``min_class_size``
    members. This is useful for ensuring enough samples per class for ML
    tasks. Class counts before and after filtering are logged.

    Args:
        features: Feature matrix of shape (n_samples, n_features). Rows are
            matched to ``labels`` by position.
        labels: Per-sample labels of length n_samples. A ``pd.Series`` is
            used as-is; any other array-like (ndarray, list, Categorical) is
            wrapped in a ``pd.Series`` first.
        min_class_size: Minimum number of samples required for a class to
            be kept.

    Returns:
        Tuple containing:
            - Filtered feature matrix (rows of ``features`` that were kept)
            - Filtered labels as a ``pd.Categorical``
    """
    label_name = labels.name if hasattr(labels, "name") else "unknown"
    logger.info(f"Label composition ({label_name}):")

    # Normalise to a Series so .value_counts()/.isin() are always available;
    # previously only ndarrays were converted and e.g. lists crashed.
    if not isinstance(labels, pd.Series):
        labels = pd.Series(labels)

    # Count classes once and reuse the result for both logging and filtering.
    class_counts = labels.value_counts()
    valid_classes = class_counts.index[class_counts >= min_class_size]
    logger.info(f"Total classes before filtering: {len(class_counts)}")
    logger.info(
        f"Total classes after filtering "
        f"(min_class_size={min_class_size}): {len(valid_classes)}"
    )

    # Use a plain boolean ndarray so both features and labels are selected
    # positionally. A pandas boolean Series would be index-aligned against a
    # DataFrame `features` whose index differs from `labels`.
    keep = labels.isin(valid_classes).to_numpy()
    return features[keep], pd.Categorical(labels[keep])
def run_standard_scrna_workflow(
    adata: AnnData, n_top_genes: int = 3000, n_pcs: int = 50, random_state: int = 42
) -> AnnData:
    """Apply a standard scRNA-seq preprocessing pipeline to a copy of ``adata``.

    The input object is copied first and is never modified. The copy goes
    through these steps in order:

    1. ``sc.pp.normalize_total``: normalize counts per cell.
    2. ``sc.pp.log1p``: log-transform.
    3. ``sc.pp.highly_variable_genes``: flag the ``n_top_genes`` most
       variable genes (scanpy's default flavor).
    4. Subset to the flagged genes.
    5. ``sc.pp.pca``: compute ``n_pcs`` principal components.

    Args:
        adata: AnnData object containing the raw count data.
        n_top_genes: Number of highly variable genes to keep.
        n_pcs: Number of principal components to compute.
        random_state: Random seed forwarded to PCA for reproducibility.

    Returns:
        A new AnnData restricted to the highly variable genes, carrying the
        PCA results written by ``sc.pp.pca``.
    """
    processed = adata.copy()

    # Depth normalisation followed by log transform, both in place on the copy.
    sc.pp.normalize_total(processed)
    sc.pp.log1p(processed)

    # Flag highly variable genes, then keep only those columns to reduce noise.
    sc.pp.highly_variable_genes(processed, n_top_genes=n_top_genes)
    hvg_mask = processed.var["highly_variable"]
    processed = processed[:, hvg_mask].copy()

    # Dimensionality reduction on the HVG-restricted matrix.
    sc.pp.pca(processed, n_comps=n_pcs, random_state=random_state)
    return processed