from enum import Enum
from typing import Any, Callable, Dict, Optional, Set
from pydantic import BaseModel, Field
[docs]
class MetricType(Enum):
"""Enumeration of all supported metric types.
Defines unique identifiers for evaluation metrics that can be computed.
Each metric type corresponds to a specific evaluation metric, and its value
is used as a string identifier in results dictionaries.
Examples:
- Clustering metrics: Adjusted Rand Index, Normalized Mutual Information
- Embedding quality metrics: Silhouette Score
- Integration metrics: Entropy Per Cell, Batch Silhouette
- Perturbation metrics: Mean Squared Error, Pearson Correlation
"""
# Clustering metrics
ADJUSTED_RAND_INDEX = "adjusted_rand_index"
NORMALIZED_MUTUAL_INFO = "normalized_mutual_info"
# Embedding quality metrics
SILHOUETTE_SCORE = "silhouette_score"
# Integration metrics
ENTROPY_PER_CELL = "entropy_per_cell"
BATCH_SILHOUETTE = "batch_silhouette"
# Regression metrics
MEAN_SQUARED_ERROR = "mean_squared_error"
PEARSON_CORRELATION = "PEARSON_CORRELATION"
# Classification metrics
ACCURACY = "accuracy"
ACCURACY_CALCULATION = "accuracy_calculation"
MEAN_FOLD_ACCURACY = "mean_fold_accuracy"
AUROC = "auroc"
MEAN_FOLD_AUROC = "mean_fold_auroc"
F1_SCORE = "f1"
F1_CALCULATION = "f1_calculation"
MEAN_FOLD_F1_SCORE = "mean_fold_f1"
JACCARD = "jaccard"
PRECISION = "precision"
PRECISION_CALCULATION = "precision_calculation"
MEAN_FOLD_PRECISION = "mean_fold_precision"
RECALL = "recall"
RECALL_CALCULATION = "recall_calculation"
MEAN_FOLD_RECALL = "mean_fold_recall"
SPEARMAN_CORRELATION_CALCULATION = "spearman_correlation_calculation"
# Sequential metrics
SEQUENTIAL_ALIGNMENT = "sequential_alignment"
[docs]
class MetricInfo(BaseModel):
"""Stores metadata about a metric.
Encapsulates information required for metric computation, including:
- The function implementing the metric.
- Required arguments for the metric function.
- Default parameters for the metric function.
- An optional description of the metric's purpose.
- Tags for grouping related metrics.
Attributes:
func (Callable): The function that computes the metric.
required_args (Set[str]): Names of required arguments for the metric function.
default_params (Dict[str, Any]): Default parameters for the metric function.
description (Optional[str]): Documentation string describing the metric.
tags (Set[str]): Tags for categorizing metrics.
"""
func: Callable
"""The function that computes the metric"""
required_args: Set[str]
"""Set of required argument names"""
default_params: Dict[str, Any]
"""Default parameters for the metric function"""
description: Optional[str] = None
"""Optional documentation string for custom metrics"""
tags: Set[str] = None
"""Set of tags for grouping related metrics"""
[docs]
class MetricRegistry:
"""Central registry for all available metrics.
Provides functionality for registering, validating, and computing metrics.
Each metric is associated with a unique `MetricType` identifier and metadata
stored in a `MetricInfo` object.
Features:
- Register new metrics with required arguments, default parameters, and tags.
- Compute metrics by passing required arguments and merging with defaults.
- Retrieve metadata about registered metrics.
- List available metrics, optionally filtered by tags.
Attributes:
_metrics (Dict[MetricType, MetricInfo]): Internal storage for registered metrics.
"""
def __init__(self):
self._metrics: Dict[MetricType, MetricInfo] = {}
[docs]
def register(
self,
metric_type: MetricType,
func: Callable,
required_args: Optional[Set[str]] = None,
default_params: Optional[Dict[str, Any]] = None,
description: str = "",
tags: Optional[Set[str]] = None,
) -> None:
"""Register a new metric in the registry.
Associates a metric type with its computation function, required arguments,
default parameters, and metadata. Registered metrics can later be computed
using the `compute` method.
Args:
metric_type (MetricType): Unique identifier for the metric.
func (Callable): Function that computes the metric.
required_args (Optional[Set[str]]): Names of required arguments for the metric function.
default_params (Optional[Dict[str, Any]]): Default parameters for the metric function.
description (str): Documentation string describing the metric's purpose.
tags (Optional[Set[str]]): Tags for categorizing the metric.
Raises:
TypeError: If `metric_type` is not an instance of `MetricType`.
"""
if not isinstance(metric_type, MetricType):
raise TypeError(
f"Invalid metric type: {metric_type}. Must be a MetricType enum."
)
self._metrics[metric_type] = MetricInfo(
func=func,
required_args=required_args or set(),
default_params=default_params or {},
description=description,
tags=tags or set(),
)
[docs]
def compute(self, metric_type: MetricType, **kwargs) -> float:
"""Compute a registered metric with the given parameters.
Validates required arguments and merges them with default parameters before
calling the metric's computation function.
Args:
metric_type (MetricType): Type of metric to compute.
**kwargs: Arguments to pass to the metric function.
Returns:
float: Computed metric value.
Raises:
ValueError: If the metric type is unknown or required arguments are missing.
"""
if metric_type not in self._metrics:
raise ValueError(f"Unknown metric type: {metric_type}")
metric_info = self._metrics[metric_type]
# Validate required arguments
missing_args = metric_info.required_args - set(kwargs.keys())
if missing_args:
raise ValueError(
f"Missing required arguments for {metric_type}: {missing_args}"
)
# Merge with defaults and compute
params = {**metric_info.default_params, **kwargs}
return metric_info.func(**params)
[docs]
def get_info(self, metric_type: MetricType) -> MetricInfo:
"""Get metadata about a metric.
Args:
metric_type: Type of metric
Returns:
MetricInfo object with metric metadata
Raises:
ValueError: If metric type unknown
"""
if metric_type not in self._metrics:
raise ValueError(f"Unknown metric type: {metric_type}")
return self._metrics[metric_type]
[docs]
def list_metrics(self, tags: Optional[Set[str]] = None) -> Set[MetricType]:
"""List available metrics, optionally filtered by tags.
Retrieves all registered metrics, or filters them based on the provided tags.
Args:
tags (Optional[Set[str]]): Tags to filter metrics. Only metrics with all
specified tags will be returned.
Returns:
Set[MetricType]: Set of matching metric types.
"""
if tags is None:
return set(self._metrics.keys())
return {
metric_type
for metric_type, info in self._metrics.items()
if tags.issubset(info.tags)
}
[docs]
class MetricResult(BaseModel):
"""Represents the result of a single metric computation.
Encapsulates the computed value, associated metric type, and any parameters
used during computation. Provides functionality for generating aggregation keys
to group similar metrics.
Attributes:
metric_type (MetricType): The type of metric computed.
value (float): The computed metric value.
params (Optional[Dict[str, Any]]): Parameters used during computation.
Methods:
aggregation_key: Generates a key based on the metric type and parameters
to aggregate similar metrics together.
"""
metric_type: MetricType
value: float
params: Optional[Dict[str, Any]] = Field(default_factory=dict)
@property
def aggregation_key(self) -> str:
"""return a key based on the metric type and params in order to aggregate the same metrics together"""
if self.params is None:
params = ""
else:
params = "_".join(
(f"{key}-{value}" for key, value in sorted(self.params.items()))
)
return f"{self.metric_type}({params})"
[docs]
class AggregatedMetricResult(BaseModel):
"""Represents the aggregated result of multiple metric computations.
Stores statistical information about a set of metric values, including the
mean, standard deviation, and raw values. Useful for summarizing metrics
computed across multiple runs or folds.
Attributes:
metric_type (MetricType): The type of metric being aggregated.
params (Dict[str, Any] | None): Parameters used during computation.
n_values (int): Number of values aggregated.
value (float): Mean value of the aggregated metrics.
value_std_dev (float | None): Standard deviation of the aggregated metrics.
values_raw (list[float]): Raw values of the metrics being aggregated.
"""
metric_type: MetricType
params: Dict[str, Any] | None = Field(default_factory=dict)
n_values: int
value: float
value_std_dev: float | None
values_raw: list[float]