"""
Implementation of metric functions and registration with the registry.
This file defines and registers various metrics with a global `MetricRegistry`.
Metrics are categorized into the following types:
- Clustering metrics (e.g., Adjusted Rand Index, Normalized Mutual Information)
- Embedding quality metrics (e.g., Silhouette Score)
- Integration metrics (e.g., Entropy Per Cell, Batch Silhouette)
- Perturbation metrics (e.g., Mean Squared Error, Pearson Correlation)
- Label prediction metrics (e.g., Mean Fold Accuracy, Mean Fold F1 Score)
Each metric is registered with:
- A unique `MetricType` identifier.
- The function implementing the metric.
- Required arguments for the metric function.
- A description of the metric's purpose.
- Tags categorizing the metric.
"""
import numpy as np
from scib_metrics import silhouette_batch, silhouette_label
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import (
    accuracy_score,
    adjusted_rand_score,
    f1_score,
    mean_squared_error,
    normalized_mutual_info_score,
    precision_score,
    recall_score,
)
from .types import MetricRegistry, MetricType
from .utils import (
    compute_entropy_per_cell,
    jaccard_score,
    mean_fold_metric,
    sequential_alignment,
    single_metric,
)
def spearman_correlation(a, b):
    """Wrapper for spearmanr that returns only the correlation coefficient.

    Returns 0.0 when the correlation is undefined (e.g., for constant input).
    """
    value = spearmanr(a, b).statistic
    return 0.0 if np.isnan(value) else value
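# Example: spearmanr yields NaN for constant input, which the wrapper maps to
# 0.0, e.g. spearman_correlation([1, 1, 1], [1, 2, 3]) -> 0.0.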
def precision_score_zero_division(y_true, y_pred, **kwargs):
    """Wrapper for precision_score with zero_division=0 to suppress warnings."""
    return precision_score(y_true, y_pred, zero_division=0, **kwargs)
def recall_score_zero_division(y_true, y_pred, **kwargs):
    """Wrapper for recall_score with zero_division=0 to suppress warnings."""
    return recall_score(y_true, y_pred, zero_division=0, **kwargs)
def f1_score_zero_division(y_true, y_pred, **kwargs):
    """Wrapper for f1_score with zero_division=0 to suppress warnings."""
    return f1_score(y_true, y_pred, zero_division=0, **kwargs)
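# Example: when there are no positive labels or predictions, precision, recall,
# and F1 are undefined; zero_division=0 returns 0.0 instead of warning,
# e.g. f1_score_zero_division([0, 0, 0], [0, 0, 0]) -> 0.0.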
# Create the global metric registry
metrics_registry = MetricRegistry()
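# Illustrative usage sketch. This module only shows `register`, so the
# `compute(...)` entry point below is a hypothetical name; adapt it to the
# actual MetricRegistry API:
#
#     ari = metrics_registry.compute(
#         MetricType.ADJUSTED_RAND_INDEX,
#         labels_true=[0, 0, 1, 1],
#         labels_pred=[1, 1, 0, 0],
#     )  # -> 1.0; ARI is invariant to permutations of cluster labels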
# Register clustering metrics
metrics_registry.register(
    MetricType.ADJUSTED_RAND_INDEX,
    func=adjusted_rand_score,
    required_args={"labels_true", "labels_pred"},
    description="Adjusted Rand index between two clusterings",
    tags={"clustering"},
)
metrics_registry.register(
    MetricType.NORMALIZED_MUTUAL_INFO,
    func=normalized_mutual_info_score,
    required_args={"labels_true", "labels_pred"},
    description="Normalized mutual information between two clusterings",
    tags={"clustering"},
)
# Register embedding quality metrics
metrics_registry.register(
    MetricType.SILHOUETTE_SCORE,
    func=silhouette_label,
    required_args={"X", "labels"},
    description="Silhouette score for clustering evaluation",
    tags={"embedding"},
)
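# Note: scib-metrics' silhouette_label averages per-label silhouette widths
# and, by default, rescales the score to [0, 1] (higher = better-separated
# labels).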
# Register integration metrics
metrics_registry.register(
    MetricType.ENTROPY_PER_CELL,
    func=compute_entropy_per_cell,
    required_args={"X", "labels"},
    description=(
        "Computes entropy of batch labels in local neighborhoods. "
        "Higher values indicate better batch mixing."
    ),
    tags={"integration"},
)
metrics_registry.register(
    MetricType.BATCH_SILHOUETTE,
    func=silhouette_batch,
    required_args={"X", "labels", "batch"},
    description=(
        "Batch-aware silhouette score that measures how well cells "
        "cluster across batches."
    ),
    tags={"integration"},
)
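# Note: silhouette_batch scores batch labels within each cell group and is
# scaled so that higher values mean better batch mixing -- the opposite
# orientation of the label silhouette above.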
# Register perturbation metrics
metrics_registry.register(
    MetricType.MEAN_SQUARED_ERROR,
    func=mean_squared_error,
    required_args={"y_true", "y_pred"},
    description="Mean squared error between true and predicted values",
    tags={"perturbation"},
)
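# Example: mean_squared_error([1.0, 2.0], [1.0, 3.0]) -> 0.5
# (squared errors 0.0 and 1.0, averaged).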
# Register classification metrics
metrics_registry.register(
    MetricType.ACCURACY,
    func=single_metric,
    required_args={"results_df", "metric"},
    default_params={"metric": "accuracy"},
    description="Accuracy score from label prediction results",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.ACCURACY_CALCULATION,
    func=accuracy_score,
    required_args={"y_true", "y_pred"},
    description="Accuracy between true and predicted values",
    tags={"label_prediction", "perturbation"},
)
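# Example: accuracy_score([0, 1, 1], [0, 1, 0]) -> 2/3 (two of three correct).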
metrics_registry.register(
    MetricType.MEAN_FOLD_ACCURACY,
    func=mean_fold_metric,
    required_args={"results_df"},
    default_params={"metric": "accuracy", "classifier": None},
    description="Mean accuracy across cross-validation folds",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.AUROC,
    func=single_metric,
    required_args={"results_df", "metric"},
    default_params={"metric": "auroc"},
    description="AUROC score from label prediction results",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.MEAN_FOLD_AUROC,
    func=mean_fold_metric,
    required_args={"results_df"},
    default_params={"metric": "auroc", "classifier": None},
    description="Mean AUROC across cross-validation folds",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.F1_SCORE,
    func=single_metric,
    required_args={"results_df", "metric"},
    default_params={"metric": "f1"},
    description="F1 score from label prediction results",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.F1_CALCULATION,
    func=f1_score_zero_division,
    required_args={"y_true", "y_pred"},
    description="F1 score between true and predicted values",
    tags={"label_prediction", "perturbation"},
)
metrics_registry.register(
    MetricType.MEAN_FOLD_F1_SCORE,
    func=mean_fold_metric,
    required_args={"results_df"},
    default_params={"metric": "f1", "classifier": None},
    description="Mean F1 score across cross-validation folds",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.JACCARD,
    func=jaccard_score,
    required_args={"y_true", "y_pred"},
    description="Jaccard similarity between true and predicted values",
    tags={"perturbation"},
)
metrics_registry.register(
    MetricType.PEARSON_CORRELATION,
    func=pearsonr,
    required_args={"x", "y"},
    description="Pearson correlation between true and predicted values",
    tags={"perturbation"},
)
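# Note: unlike spearman_correlation above, pearsonr returns a result object
# (statistic and p-value) rather than a bare float; consumers of this metric
# are assumed to extract the coefficient themselves.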
metrics_registry.register(
    MetricType.PRECISION,
    func=single_metric,
    required_args={"results_df", "metric"},
    default_params={"metric": "precision"},
    description="Precision score from label prediction results",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.PRECISION_CALCULATION,
    func=precision_score_zero_division,
    required_args={"y_true", "y_pred"},
    description="Precision between true and predicted values",
    tags={"label_prediction", "perturbation"},
)
metrics_registry.register(
    MetricType.MEAN_FOLD_PRECISION,
    func=mean_fold_metric,
    required_args={"results_df"},
    default_params={"metric": "precision", "classifier": None},
    description="Mean precision across cross-validation folds",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.RECALL,
    func=single_metric,
    required_args={"results_df", "metric"},
    default_params={"metric": "recall"},
    description="Recall score from label prediction results",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.RECALL_CALCULATION,
    func=recall_score_zero_division,
    required_args={"y_true", "y_pred"},
    description="Recall between true and predicted values",
    tags={"label_prediction", "perturbation"},
)
metrics_registry.register(
    MetricType.MEAN_FOLD_RECALL,
    func=mean_fold_metric,
    required_args={"results_df"},
    default_params={"metric": "recall", "classifier": None},
    description="Mean recall across cross-validation folds",
    tags={"label_prediction"},
)
metrics_registry.register(
    MetricType.SEQUENTIAL_ALIGNMENT,
    func=sequential_alignment,
    required_args={"X", "labels"},
    description="Sequential alignment score measuring consistency in embeddings",
    tags={"sequential"},
)
metrics_registry.register(
    MetricType.SPEARMAN_CORRELATION_CALCULATION,
    func=spearman_correlation,
    required_args={"a", "b"},
    description="Spearman correlation between true and predicted values",
    tags={"label_prediction", "perturbation"},
)