"""Helpers for listing, loading, and instantiating datasets (czbenchmarks.datasets.utils)."""

import os
from typing import Dict, Optional, Any
from pathlib import Path
from urllib.parse import urlparse
import logging

import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

from czbenchmarks.datasets.dataset import Dataset
from czbenchmarks.file_utils import download_file_from_remote
from czbenchmarks.utils import initialize_hydra, load_custom_config

# Module-level logger named after this module (standard ``logging`` convention).
logger = logging.getLogger(__name__)


def list_available_datasets() -> Dict[str, Dict[str, str]]:
    """
    Return metadata for every dataset defined in the `datasets.yaml` Hydra
    configuration, keyed by dataset name in alphabetical order.

    Returns:
        Dict[str, Dict[str, str]]: Mapping of dataset name to a dictionary with
        two entries:

        - ``"organism"``: the configured ``organism`` value converted with
          ``str()``, or ``"Unknown"`` if the entry has no ``organism`` field.
        - ``"url"``: the configured ``path`` value, or ``"Unknown"`` if the
          entry has no ``path`` field.

        Dataset names are inserted in sorted order, so iterating the result
        yields them alphabetically.

    Notes:
        - Loads configuration using Hydra, with interpolations resolved.
        - A missing or null ``datasets`` section yields an empty dict.
    """
    initialize_hydra()

    # Convert to plain Python containers (resolving interpolations) so that
    # ``.get`` with a default works on every nested entry.
    cfg = OmegaConf.to_container(hydra.compose(config_name="datasets"), resolve=True)

    # ``or {}`` also covers a ``datasets:`` key that is present but null,
    # which would otherwise fail on ``.items()``.
    dataset_cfgs = cfg.get("datasets") or {}

    # Iterating in sorted order makes the resulting dict preserve alphabetical
    # insertion order.
    return {
        name: {
            "organism": str(dataset_info.get("organism", "Unknown")),
            "url": dataset_info.get("path", "Unknown"),
        }
        for name, dataset_info in sorted(dataset_cfgs.items())
    }
def load_dataset(
    dataset_name: str,
    cache_dir: Optional[str] = None,
) -> Dataset:
    """
    Load, download (if needed), and instantiate a dataset using Hydra configuration.

    Args:
        dataset_name (str): Name of the dataset as specified in the configuration.
        cache_dir (Optional[str]): Optional directory in which to cache the
            dataset file. When ``None`` (the default), ``download_file_from_remote``
            is called without a ``cache_dir`` argument and uses its own default.

    Returns:
        Dataset: Instantiated dataset object on which ``load_data()`` has
        already been called.

    Raises:
        ValueError: If the specified dataset is not found in the configuration.

    Notes:
        - Uses Hydra for instantiation and configuration management.
        - The configured ``path`` is always passed through
          ``download_file_from_remote`` and replaced with its return value
          before instantiation.
        - The returned dataset object is an instance of the `Dataset` class or
          its subclass.
    """
    initialize_hydra()
    cfg = hydra.compose(config_name="datasets")

    if dataset_name not in cfg.datasets:
        raise ValueError(f"Dataset {dataset_name} not found in config")

    dataset_info = cfg.datasets[dataset_name]

    # Forward cache_dir only when the caller supplied one, so the default call
    # to download_file_from_remote is exactly the same as before.
    download_kwargs = {} if cache_dir is None else {"cache_dir": cache_dir}
    dataset_info["path"] = download_file_from_remote(
        dataset_info["path"], **download_kwargs
    )

    # Instantiate the dataset class named in the config, then load it into memory.
    dataset = instantiate(dataset_info)
    dataset.load_data()
    return dataset
def load_custom_dataset(
    dataset_name: str,
    custom_dataset_config_path: Optional[str] = None,
    custom_dataset_kwargs: Optional[Dict[str, Any]] = None,
    cache_dir: Optional[str] = None,
) -> Dataset:
    """
    Instantiate a dataset with a custom configuration.

    This can include but is not limited to a local path for a custom dataset
    file and/or a dictionary of custom parameters to update the default
    configuration. If the dataset name does not exist in the default config,
    this function will add the dataset to the configuration.

    Args:
        dataset_name: The name of the dataset, either custom or from the config.
        custom_dataset_config_path: Optional path to a YAML file containing a
            custom configuration that can be used to update the existing
            default configuration.
        custom_dataset_kwargs: Custom configuration dictionary to update the
            default configuration of the dataset class.
        cache_dir: Optional directory to cache the dataset file. Only used when
            the resolved ``path`` is a remote URL. If not provided, the global
            cache manager directory will be used.

    Returns:
        Instantiated dataset object on which ``load_data()`` has been called.

    Raises:
        ValueError: If the resolved configuration has no ``path`` entry.
        FileNotFoundError: If ``path`` is local and does not exist.

    Example:
        ```python
        from czbenchmarks.datasets.types import Organism
        from czbenchmarks.datasets.utils import load_custom_dataset

        custom_dataset_config_path = "/path/to/new_dataset.yaml"
        my_dataset_name = "my_dataset"
        custom_dataset_kwargs = {
            "organism": Organism.HUMAN,
            "path": "example-small.h5ad",
        }

        dataset = load_custom_dataset(
            dataset_name=my_dataset_name,
            custom_dataset_config_path=custom_dataset_config_path,
            custom_dataset_kwargs=custom_dataset_kwargs,
        )
        ```
    """
    custom_cfg = load_custom_config(
        item_name=dataset_name,
        config_name="datasets",
        custom_config_path=custom_dataset_config_path,
        class_update_kwargs=custom_dataset_kwargs,
    )

    if "path" not in custom_cfg:
        raise ValueError(
            f"Path required but not found in resolved configuration: {custom_cfg}"
        )

    path = custom_cfg["path"]
    scheme = urlparse(str(path)).scheme

    # urlparse reports a Windows drive letter ("C:\\data\\x.h5ad") as a
    # one-character scheme ("c"). Real URL schemes are longer, so only
    # multi-character schemes are treated as remote.
    if len(scheme) > 1:
        custom_cfg["path"] = download_file_from_remote(path, cache_dir=cache_dir)
    else:
        resolved_path = str(Path(path).expanduser().resolve())
        if not os.path.exists(resolved_path):
            raise FileNotFoundError(
                f"Local dataset file not found at path: {resolved_path}"
            )
        logger.info("Local dataset file found: %s", resolved_path)
        custom_cfg["path"] = resolved_path

    dataset = instantiate(custom_cfg)
    dataset.load_data()
    return dataset