Source code for czbenchmarks.datasets.utils

import os
import hydra
from hydra.utils import instantiate
from typing import List, Optional
import yaml
from omegaconf import OmegaConf
from ..constants import DATASETS_CACHE_PATH
from ..utils import initialize_hydra, download_file_from_remote
from .base import BaseDataset


def load_dataset(
    dataset_name: str,
    config_path: Optional[str] = None,
) -> BaseDataset:
    """
    Download and instantiate a dataset using Hydra configuration.

    Args:
        dataset_name: Name of dataset as specified in config
        config_path: Optional path to config yaml file. If not provided,
            will use only the package's default config.

    Returns:
        BaseDataset: Instantiated dataset object
    """
    initialize_hydra()

    # Load default config first and make it unstructured
    cfg = OmegaConf.create(
        OmegaConf.to_container(hydra.compose(config_name="datasets"), resolve=True)
    )

    # If custom config provided, load and merge it
    if config_path is not None:
        # Expand user path (handles ~)
        config_path = os.path.expanduser(config_path)
        config_path = os.path.abspath(config_path)

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Custom config file not found: {config_path}")

        # Load custom config
        with open(config_path) as f:
            custom_cfg = OmegaConf.create(yaml.safe_load(f))

        # Merge configs
        cfg = OmegaConf.merge(cfg, custom_cfg)

    if dataset_name not in cfg.datasets:
        raise ValueError(f"Dataset {dataset_name} not found in config")

    dataset_info = cfg.datasets[dataset_name]
    original_path = dataset_info.path
    is_s3_path = original_path.startswith("s3://")
    expanded_path = os.path.expanduser(original_path)

    if not is_s3_path:
        if not os.path.exists(expanded_path):
            raise FileNotFoundError(f"Local dataset file not found: {expanded_path}")
    else:
        # Setup cache path
        cache_path = os.path.expanduser(DATASETS_CACHE_PATH)
        os.makedirs(cache_path, exist_ok=True)
        cache_file = os.path.join(cache_path, f"{dataset_name}.h5ad")

        # Only download if file doesn't exist
        if not os.path.exists(cache_file):
            download_file_from_remote(
                original_path, cache_path, f"{dataset_name}.h5ad"
            )

        # Update path to cached file
        dataset_info.path = cache_file

    # Instantiate the dataset using Hydra
    dataset = instantiate(dataset_info)
    dataset.path = os.path.expanduser(dataset.path)
    return dataset

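# Illustrative usage sketch (not part of the module). It assumes "tsv2_lung" is
# defined in the datasets config; the custom config path is a hypothetical example.
#
#     from czbenchmarks.datasets.utils import load_dataset
#
#     dataset = load_dataset("tsv2_lung")
#     print(dataset.path)
#
#     # Merge a user-provided YAML on top of the package defaults:
#     dataset = load_dataset("tsv2_lung", config_path="~/my_datasets.yaml")
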
def list_available_datasets() -> List[str]:
    """
    Lists all available datasets defined in the datasets.yaml configuration file.

    Returns:
        list: A sorted list of dataset names available in the configuration.
    """
    initialize_hydra()

    # Load the datasets configuration
    cfg = OmegaConf.to_container(hydra.compose(config_name="datasets"), resolve=True)

    # Extract dataset names
    dataset_names = list(cfg.get("datasets", {}).keys())

    # Sort alphabetically for easier reading
    dataset_names.sort()

    return dataset_names

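# Illustrative usage sketch (not part of the module); the names printed depend on
# the installed configuration.
#
#     from czbenchmarks.datasets.utils import list_available_datasets
#
#     for name in list_available_datasets():
#         print(name)
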
_DATASET_TO_DISPLAY = {
    "adamson_perturb": ("Adamson", ""),
    "norman_perturb": ("Norman", ""),
    "dixit_perturb": ("Dixit", ""),
    "replogle_k562_perturb": ("Replogle", "K562"),
    "replogle_rpe1_perturb": ("Replogle", "RPE1"),
    "human_spermatogenesis": ("Spermatogenesis", "Homo sapiens"),
    "mouse_spermatogenesis": ("Spermatogenesis", "Mus musculus"),
    "rhesus_macaque_spermatogenesis": ("Spermatogenesis", "Macaca mulatta"),
    "gorilla_spermatogenesis": ("Spermatogenesis", "Gorilla gorilla"),
    "chimpanzee_spermatogenesis": ("Spermatogenesis", "Pan troglodytes"),
    "marmoset_spermatogenesis": ("Spermatogenesis", "Callithrix jacchus"),
    "chicken_spermatogenesis": ("Spermatogenesis", "Gallus gallus"),
    "opossum_spermatogenesis": ("Spermatogenesis", "Monodelphis domestica"),
    "platypus_spermatogenesis": ("Spermatogenesis", "Ornithorhynchus anatinus"),
    "tsv2_bladder": ("Tabula Sapiens 2.0", "Bladder"),
    "tsv2_blood": ("Tabula Sapiens 2.0", "Blood"),
    "tsv2_bone_marrow": ("Tabula Sapiens 2.0", "Bone marrow"),
    "tsv2_ear": ("Tabula Sapiens 2.0", "Ear"),
    "tsv2_eye": ("Tabula Sapiens 2.0", "Eye"),
    "tsv2_fat": ("Tabula Sapiens 2.0", "Fat"),
    "tsv2_heart": ("Tabula Sapiens 2.0", "Heart"),
    "tsv2_large_intestine": ("Tabula Sapiens 2.0", "Large intestine"),
    "tsv2_liver": ("Tabula Sapiens 2.0", "Liver"),
    "tsv2_lung": ("Tabula Sapiens 2.0", "Lung"),
    "tsv2_lymph_node": ("Tabula Sapiens 2.0", "Lymph node"),
    "tsv2_mammary": ("Tabula Sapiens 2.0", "Mammary"),
    "tsv2_muscle": ("Tabula Sapiens 2.0", "Muscle"),
    "tsv2_ovary": ("Tabula Sapiens 2.0", "Ovary"),
    "tsv2_prostate": ("Tabula Sapiens 2.0", "Prostate"),
    "tsv2_salivary_gland": ("Tabula Sapiens 2.0", "Salivary gland"),
    "tsv2_skin": ("Tabula Sapiens 2.0", "Skin"),
    "tsv2_small_intestine": ("Tabula Sapiens 2.0", "Small intestine"),
    "tsv2_spleen": ("Tabula Sapiens 2.0", "Spleen"),
    "tsv2_stomach": ("Tabula Sapiens 2.0", "Stomach"),
    "tsv2_testis": ("Tabula Sapiens 2.0", "Testis"),
    "tsv2_thymus": ("Tabula Sapiens 2.0", "Thymus"),
    "tsv2_tongue": ("Tabula Sapiens 2.0", "Tongue"),
    "tsv2_trachea": ("Tabula Sapiens 2.0", "Trachea"),
    "tsv2_uterus": ("Tabula Sapiens 2.0", "Uterus"),
    "tsv2_vasculature": ("Tabula Sapiens 2.0", "Vasculature"),
}

def dataset_to_display(dataset_name: str) -> tuple[str, str]:
    """Try to map dataset names to more uniform, pretty strings."""
    try:
        return _DATASET_TO_DISPLAY[dataset_name]
    except KeyError:
        # e.g. "my_awesome_dataset" -> "My awesome dataset"
        parts = dataset_name.split("_")
        return (
            " ".join((parts[0].title(), *(part.lower() for part in parts[1:]))),
            "",
        )
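
# Illustrative usage sketch (not part of the module); the second call shows the
# fallback for a hypothetical name that is absent from _DATASET_TO_DISPLAY.
#
#     dataset_to_display("replogle_k562_perturb")  # -> ("Replogle", "K562")
#     dataset_to_display("my_awesome_dataset")     # -> ("My awesome dataset", "")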