Source code for czbenchmarks.cli.utils

import functools
import importlib.metadata
import json
import logging
import subprocess
from pathlib import Path
from typing import Any, List

import click
import numpy as np
import pandas as pd
import tomli

from ..datasets import Dataset, load_dataset
from ..metrics.utils import aggregate_results
from ..tasks.types import CellRepresentation

log = logging.getLogger(__name__)

_REPO_PATH = Path(__file__).parent.parent.parent.parent


def _get_pyproject_version() -> str:
    """
    Best-effort lookup of the project version recorded in pyproject.toml

    Reads `[project].version` from the pyproject.toml located at the
    repository root and returns it as a string. Any failure (file missing,
    invalid TOML, key absent, ...) is logged with its traceback and the
    placeholder 'unknown' is returned instead of raising.
    """
    toml_file = _REPO_PATH / "pyproject.toml"

    try:
        with toml_file.open("rb") as handle:
            project_table = tomli.load(handle)["project"]
        return str(project_table["version"])
    except Exception:
        # deliberately broad: version discovery must never crash the caller
        log.exception("Could not determine cz-benchmarks version from pyproject.toml")
        return "unknown"


# Ways a git invocation can fail to give us an answer: git ran and exited
# non-zero (CalledProcessError), or the executable could not be launched at
# all, e.g. git is not installed (FileNotFoundError, a subclass of OSError).
_GIT_FAILURES = (subprocess.CalledProcessError, OSError)


def _run_git(*args: str) -> str:
    """
    Run `git -C <repo root> <args...>` and return its stdout, stripped.

    stderr is discarded. Raises subprocess.CalledProcessError when git exits
    with a non-zero status and OSError when the git executable cannot be
    started.
    """
    return subprocess.check_output(
        ["git", "-C", str(_REPO_PATH), *args],
        stderr=subprocess.DEVNULL,
        text=True,
    ).strip()


def _get_git_commit(base_version: str) -> str:
    """
    Return '' if the repo is exactly at the tag matching `base_version`
    (which should be what's in the pyproject file, with NO 'v' prepended)
    or '+<short-sha>[.dirty]' if not, where '.dirty' is added when there
    are uncommitted changes

    Also returns '' when the repository root has no `.git` entry.

    Every git query is best-effort: if git fails or is not available the
    problem is logged and a pessimistic fallback is used (tag commit
    'error', current commit 'unknown', working tree treated as dirty), so
    this function does not raise on git problems.
    """
    if not (_REPO_PATH / ".git").exists():
        return ""

    tag = "v" + base_version  # this is our convention
    try:
        tag_commit = _run_git("rev-list", "-n", "1", tag)
    except _GIT_FAILURES:
        log.error("Could not find a commit hash for tag %r in git", tag)
        tag_commit = "error"

    try:
        commit = _run_git("rev-parse", "HEAD")
    except _GIT_FAILURES:
        log.error("Could not get current commit hash from git")
        commit = "unknown"

    try:
        # `git status --porcelain` prints nothing when the repo is clean
        is_dirty = bool(_run_git("status", "--porcelain"))
    except _GIT_FAILURES:
        log.error("Could not get repo status from git")
        is_dirty = True

    if tag_commit == commit and not is_dirty:
        # if we're on the commit matching the version tag, then our version is simply the tag
        return ""

    # otherwise we want to add the commit hash and dirty status
    dirty_string = ".dirty" if is_dirty else ""
    return f"+{commit[:7]}{dirty_string}"


@functools.cache
def get_version() -> str:
    """
    Return the version string of the cz-benchmarks library.

    The base version comes from the installed package metadata; when the
    package is not installed it is read from pyproject.toml instead. The
    result is 'v' + base version + the git suffix computed for that base
    version (empty, or '+<short-sha>[.dirty]'). The value is computed once
    and cached for subsequent calls.
    """
    try:
        base = importlib.metadata.version("cz-benchmarks")  # yes, with the hyphen
    except importlib.metadata.PackageNotFoundError:
        log.debug(
            "Package `cz-benchmarks` is not installed: fetching version info from pyproject.toml"
        )
        base = _get_pyproject_version()

    suffix = _get_git_commit(base)
    return "".join(("v", base, suffix))
def get_datasets(dataset_names: List[str]) -> List[Dataset]:
    """
    Loads a list of datasets by name.

    Each entry of `dataset_names` is passed to `load_dataset`, and the
    loaded datasets are returned in the same order.

    Raises:
        click.UsageError: if loading any dataset raises; the original
            exception is attached as the cause.
    """
    try:
        return [load_dataset(name) for name in dataset_names]
    except Exception as e:
        # Chain explicitly so a traceback shows the loader failure as the
        # direct cause, not as an unrelated error "during handling".
        raise click.UsageError(f"Failed to load dataset: {e}") from e
def load_embedding(path: str) -> CellRepresentation:
    """
    Loads a model embedding from a file.

    Supported formats, selected by the file name's ending:
      * `.npy` - loaded with `np.load`
      * `.csv` - read with pandas using the first column as the index;
        the remaining values are returned as an array

    Raises:
        click.BadParameter: if the extension is not supported or reading
            the file fails; the original exception is attached as the cause.
    """
    try:
        if path.endswith(".npy"):
            return np.load(path)
        if path.endswith(".csv"):
            return pd.read_csv(path, index_col=0).values
        # Raised inside the try on purpose: it is reported to the user
        # through the same BadParameter message as any read failure.
        raise NotImplementedError(
            "Only .npy and .csv embedding files are supported."
        )
    except Exception as e:
        # Chain explicitly so the underlying I/O or parse error is kept as
        # the direct cause of the CLI error.
        raise click.BadParameter(f"Could not load embedding from '{path}': {e}") from e
def write_results(results: List[Any], output_file: str | None):
    """
    Aggregate the given results and emit them as indented JSON.

    The results are passed through `aggregate_results`, and each aggregated
    entry is serialised with `model_dump(mode="json")`. When `output_file`
    is truthy the JSON is written to that path and a confirmation line is
    echoed; otherwise the JSON is echoed to the terminal under a
    '--- RESULTS ---' banner.
    """
    serialisable = [
        entry.model_dump(mode="json") for entry in aggregate_results(results)
    ]

    if not output_file:
        # no destination given (None or empty string): print to the terminal
        click.echo("\n--- RESULTS ---")
        click.echo(json.dumps(serialisable, indent=2))
        return

    with open(output_file, "w") as handle:
        json.dump(serialisable, handle, indent=2)
    click.echo(f"Results saved to {output_file}")
def mutually_exclusive(*options):
    """
    Build a callback that rejects a value when a conflicting option is set.

    `options` are the parameter names (as they appear in `ctx.params`) that
    must not be combined with the parameter the callback is attached to.

    The returned callback has the signature `(ctx, param, value)` and
    returns `value` unchanged. If `value` is not None and any of the named
    options already has a non-None entry in `ctx.params`, it raises
    `click.UsageError` naming both options (underscores shown as dashes).
    Only options present in `ctx.params` at the time the callback runs can
    be detected.
    """

    def _callback(ctx, param, value):
        # nothing supplied for this option, so there is nothing to conflict with
        if value is None:
            return value

        for other in options:
            if ctx.params.get(other) is None:
                continue
            raise click.UsageError(
                f"Options --{param.name.replace('_', '-')} and --{other.replace('_', '-')} are mutually exclusive."
            )

        return value

    return _callback