Source code for czbenchmarks.tasks.label_prediction
import logging
from typing import Set, List
import pandas as pd
import scipy as sp
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score,
f1_score,
make_scorer,
precision_score,
recall_score,
roc_auc_score,
)
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from ..models.types import ModelType
from ..datasets import BaseDataset, DataType
from ..metrics import metrics_registry
from ..metrics.types import MetricResult, MetricType
from .base import BaseTask
from .utils import filter_minimum_class
from .constants import RANDOM_SEED, N_FOLDS, MIN_CLASS_SIZE
logger = logging.getLogger(__name__)
class MetadataLabelPredictionTask(BaseTask):
"""Task for predicting labels from embeddings using cross-validation.
Evaluates multiple classifiers (Logistic Regression, KNN, Random Forest) using
stratified k-fold cross-validation and reports standard classification metrics.
Args:
label_key: Key to access ground truth labels in metadata
n_folds: Number of cross-validation folds
random_seed: Random seed for reproducibility
min_class_size: Minimum samples required per class
"""
def __init__(
self,
label_key: str,
n_folds: int = N_FOLDS,
random_seed: int = RANDOM_SEED,
min_class_size: int = MIN_CLASS_SIZE,
):
self.label_key = label_key
self.n_folds = n_folds
self.random_seed = random_seed
self.min_class_size = min_class_size
logger.info(
"Initialized MetadataLabelPredictionTask with: "
f"label_key='{label_key}', n_folds={n_folds}, "
f"min_class_size={min_class_size}, random_seed={random_seed}"
)
@property
def display_name(self) -> str:
"""A pretty name to use when displaying task results"""
return "metadata label prediction"
@property
def required_inputs(self) -> Set[DataType]:
"""Required input data types.
Returns:
Set of required input DataTypes (metadata with labels)
"""
return {DataType.METADATA}
@property
def required_outputs(self) -> Set[DataType]:
"""Required output data types.
Returns:
Set of output DataTypes that models must produce for this task to run (embedding coordinates)
"""
return {DataType.EMBEDDING}
def _run_task(self, data: BaseDataset, model_type: ModelType):
"""Runs cross-validation prediction task.
Evaluates multiple classifiers using k-fold cross-validation on the
embedding data. Stores results for metric computation.
Args:
data: Dataset containing embedding output and ground truth labels
model_type: Model type whose embedding output is evaluated
"""
logger.info(f"Starting prediction task for label key: {self.label_key}")
# Get embedding and labels
embeddings = data.get_output(model_type, DataType.EMBEDDING)
labels = data.get_input(DataType.METADATA)[self.label_key]
logger.info(
f"Initial data shape: {embeddings.shape}, labels shape: {labels.shape}"
)
# Filter classes with minimum size requirement
embeddings, labels = filter_minimum_class(
embeddings, labels, min_class_size=self.min_class_size
)
logger.info(f"After filtering: {embeddings.shape} samples remaining")
# Determine scoring metrics based on number of classes
n_classes = len(labels.unique())
target_type = "binary" if n_classes == 2 else "macro"
logger.info(
f"Found {n_classes} classes, using {target_type} averaging for metrics"
)
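# AUROC is scored from class-probability estimates (predict_proba); the other
# scorers use hard predictions with the averaging mode chosen above.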
scorers = {
"accuracy": make_scorer(accuracy_score),
"f1": make_scorer(f1_score, average=target_type),
"precision": make_scorer(precision_score, average=target_type),
"recall": make_scorer(recall_score, average=target_type),
"auroc": make_scorer(
roc_auc_score,
average="weighted",
multi_class="ovr",
response_method="predict_proba",
),
}
# Setup cross validation
skf = StratifiedKFold(
n_splits=self.n_folds, shuffle=True, random_state=self.random_seed
)
logger.info(
f"Using {self.n_folds}-fold cross validation with random_seed {self.random_seed}"
)
# Create classifiers
classifiers = {
"lr": Pipeline(
[("scaler", StandardScaler()), ("lr", LogisticRegression())]
),
"knn": Pipeline(
[("scaler", StandardScaler()), ("knn", KNeighborsClassifier())]
),
"rf": Pipeline(
[("rf", RandomForestClassifier(random_state=self.random_seed))]
),
}
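# lr and knn are scale-sensitive, so they are preceded by a StandardScaler;
# random forests are scale-invariant and are left unscaled.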
logger.info(f"Created classifiers: {list(classifiers.keys())}")
# Store results and predictions
self.results = []
self.predictions = []
# Run cross validation for each classifier
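# Encode string labels as integer codes so the estimators receive numeric targets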
labels = pd.Categorical(labels.astype(str))
for name, clf in classifiers.items():
logger.info(f"Running cross-validation for {name}...")
cv_results = cross_validate(
clf,
embeddings,
labels.codes,
cv=skf,
scoring=scorers,
return_train_score=False,
)
for fold in range(self.n_folds):
fold_results = {"classifier": name, "split": fold}
for metric in scorers.keys():
fold_results[metric] = cv_results[f"test_{metric}"][fold]
self.results.append(fold_results)
logger.debug(f"{name} fold {fold} results: {fold_results}")
logger.info("Completed cross-validation for all classifiers")
def _compute_metrics(self) -> List[MetricResult]:
"""Computes classification metrics across all folds.
Aggregates results from cross-validation and computes mean metrics
per classifier and overall.
Returns:
List of MetricResult objects containing mean metrics across all classifiers
and per-classifier metrics
"""
logger.info("Computing final metrics...")
results_df = pd.DataFrame(self.results)
metrics_list = []
# Calculate overall metrics across all classifiers
metrics_list.extend(
[
MetricResult(
metric_type=MetricType.MEAN_FOLD_ACCURACY,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_ACCURACY, results_df=results_df
),
),
MetricResult(
metric_type=MetricType.MEAN_FOLD_F1_SCORE,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_F1_SCORE, results_df=results_df
),
),
MetricResult(
metric_type=MetricType.MEAN_FOLD_PRECISION,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_PRECISION, results_df=results_df
),
),
MetricResult(
metric_type=MetricType.MEAN_FOLD_RECALL,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_RECALL, results_df=results_df
),
),
MetricResult(
metric_type=MetricType.MEAN_FOLD_AUROC,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_AUROC, results_df=results_df
),
),
]
)
# Calculate per-classifier metrics
for clf in results_df["classifier"].unique():
metrics_list.extend(
[
MetricResult(
metric_type=MetricType.MEAN_FOLD_ACCURACY,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_ACCURACY,
results_df=results_df,
classifier=clf,
),
params={"classifier": clf},
),
MetricResult(
metric_type=MetricType.MEAN_FOLD_F1_SCORE,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_F1_SCORE,
results_df=results_df,
classifier=clf,
),
params={"classifier": clf},
),
MetricResult(
metric_type=MetricType.MEAN_FOLD_PRECISION,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_PRECISION,
results_df=results_df,
classifier=clf,
),
params={"classifier": clf},
),
MetricResult(
metric_type=MetricType.MEAN_FOLD_RECALL,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_RECALL,
results_df=results_df,
classifier=clf,
),
params={"classifier": clf},
),
MetricResult(
metric_type=MetricType.MEAN_FOLD_AUROC,
value=metrics_registry.compute(
MetricType.MEAN_FOLD_AUROC,
results_df=results_df,
classifier=clf,
),
params={"classifier": clf},
),
]
)
return metrics_list
def set_baseline(self, data: BaseDataset):
"""Set a baseline embedding using raw gene expression.
Instead of using embeddings from a model, this method uses the raw gene
expression matrix as features for classification. This provides a baseline
performance to compare against model-generated embeddings for classification
tasks.
Args:
data: BaseDataset containing AnnData with gene expression and metadata
"""
# Get the AnnData object from the dataset
adata = data.get_input(DataType.ANNDATA)
# Extract gene expression matrix
X = adata.X
# Convert sparse matrix to dense if needed
if sp.sparse.issparse(X):
X = X.toarray()
# Use raw gene expression as the "embedding" for baseline classification
data.set_output(ModelType.BASELINE, DataType.EMBEDDING, X)
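# --- Illustrative usage sketch (not part of the module source) ---
# A minimal example of how this task might be wired up, assuming a prepared
# `dataset` (a BaseDataset with ANNDATA, METADATA, and any model embeddings
# already attached); the `dataset` variable and the "cell_type" label key are
# assumptions for illustration only. Only names defined or imported in this
# module are used; the public entry point lives on BaseTask and is not shown here.
#
# task = MetadataLabelPredictionTask(label_key="cell_type", n_folds=5)
# task.set_baseline(dataset)                    # raw expression as baseline embedding
# task._run_task(dataset, ModelType.BASELINE)   # private hooks defined above;
# metrics = task._compute_metrics()             # normally invoked via BaseTask
# for m in metrics:
#     print(m.metric_type, m.value, m.params)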