import os
from typing import List, Optional

import hydra
import yaml
from hydra.utils import instantiate
from omegaconf import OmegaConf

from ..constants import DATASETS_CACHE_PATH
from ..utils import initialize_hydra, download_file_from_remote
from .base import BaseDataset


def load_dataset(
    dataset_name: str,
    config_path: Optional[str] = None,
) -> BaseDataset:
    """
    Download and instantiate a dataset using Hydra configuration.

    Args:
        dataset_name: Name of the dataset as specified in the config.
        config_path: Optional path to a config YAML file. If not provided,
            only the package's default config is used.

    Returns:
        BaseDataset: Instantiated dataset object.
    """
    initialize_hydra()

    # Load the default config first and make it unstructured
    cfg = OmegaConf.create(
        OmegaConf.to_container(hydra.compose(config_name="datasets"), resolve=True)
    )

    # If a custom config is provided, load and merge it
    if config_path is not None:
        # Expand user path (handles ~)
        config_path = os.path.expanduser(config_path)
        config_path = os.path.abspath(config_path)

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Custom config file not found: {config_path}")

        # Load the custom config
        with open(config_path) as f:
            custom_cfg = OmegaConf.create(yaml.safe_load(f))

        # Merge configs
        cfg = OmegaConf.merge(cfg, custom_cfg)

    if dataset_name not in cfg.datasets:
        raise ValueError(f"Dataset {dataset_name} not found in config")

    dataset_info = cfg.datasets[dataset_name]
    original_path = dataset_info.path
    is_s3_path = original_path.startswith("s3://")
    expanded_path = os.path.expanduser(original_path)

    if not is_s3_path:
        if not os.path.exists(expanded_path):
            raise FileNotFoundError(f"Local dataset file not found: {expanded_path}")
    else:
        # Set up the cache path
        cache_path = os.path.expanduser(DATASETS_CACHE_PATH)
        os.makedirs(cache_path, exist_ok=True)
        cache_file = os.path.join(cache_path, f"{dataset_name}.h5ad")

        # Only download if the file doesn't already exist
        if not os.path.exists(cache_file):
            download_file_from_remote(original_path, cache_path, f"{dataset_name}.h5ad")

        # Update the path to point at the cached file
        dataset_info.path = cache_file

    # Instantiate the dataset using Hydra
    dataset = instantiate(dataset_info)
    dataset.path = os.path.expanduser(dataset.path)
    return dataset
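
# Usage sketch (illustrative, not part of the module). ``load_dataset`` looks the
# key up under ``datasets:`` in the merged config, and Hydra's ``instantiate``
# expects each entry to carry a ``_target_`` plus the ``path`` read above (a local
# file or an ``s3://`` URI). The YAML layout and class path below are hypothetical,
# inferred from this function rather than taken from the shipped default config:
#
#     # ~/configs/my_datasets.yaml
#     # datasets:
#     #   my_dataset:
#     #     _target_: my_package.datasets.MyDataset   # hypothetical class path
#     #     path: ~/data/my_dataset.h5ad
#
#     dataset = load_dataset("adamson_perturb")
#     custom = load_dataset("my_dataset", config_path="~/configs/my_datasets.yaml")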


def list_available_datasets() -> List[str]:
    """
    Lists all available datasets defined in the datasets.yaml configuration file.

    Returns:
        list: A sorted list of dataset names available in the configuration.
    """
    initialize_hydra()

    # Load the datasets configuration
    cfg = OmegaConf.to_container(hydra.compose(config_name="datasets"), resolve=True)

    # Extract dataset names
    dataset_names = list(cfg.get("datasets", {}).keys())

    # Sort alphabetically for easier reading
    dataset_names.sort()
    return dataset_names
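
# Example (a minimal sketch; the actual output depends on the datasets.yaml
# bundled with the package):
#
#     for name in list_available_datasets():
#         print(name, "->", dataset_to_display_name(name))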


_DATASET_TO_DISPLAY_NAME = {
    "adamson_perturb": "Adamson",
    "norman_perturb": "Norman",
    "dixit_perturb": "Dixit",
    "replogle_k562_perturb": "Replogle K562",
    "replogle_rpe1_perturb": "Replogle RPE1",
    "human_spermatogenesis": "Spermatogenesis - Homo sapiens",
    "mouse_spermatogenesis": "Spermatogenesis - Mus musculus",
    "rhesus_macaque_spermatogenesis": "Spermatogenesis - Macaca mulatta",
    "gorilla_spermatogenesis": "Spermatogenesis - Gorilla gorilla",
    "chimpanzee_spermatogenesis": "Spermatogenesis - Pan troglodytes",
    "marmoset_spermatogenesis": "Spermatogenesis - Callithrix jacchus",
    "chicken_spermatogenesis": "Spermatogenesis - Gallus gallus",
    "opossum_spermatogenesis": "Spermatogenesis - Monodelphis domestica",
    "platypus_spermatogenesis": "Spermatogenesis - Ornithorhynchus anatinus",
    "tsv2_bladder": "Tabula Sapiens 2.0 - Bladder",
    "tsv2_blood": "Tabula Sapiens 2.0 - Blood",
    "tsv2_bone_marrow": "Tabula Sapiens 2.0 - Bone marrow",
    "tsv2_ear": "Tabula Sapiens 2.0 - Ear",
    "tsv2_eye": "Tabula Sapiens 2.0 - Eye",
    "tsv2_fat": "Tabula Sapiens 2.0 - Fat",
    "tsv2_heart": "Tabula Sapiens 2.0 - Heart",
    "tsv2_large_intestine": "Tabula Sapiens 2.0 - Large intestine",
    "tsv2_liver": "Tabula Sapiens 2.0 - Liver",
    "tsv2_lung": "Tabula Sapiens 2.0 - Lung",
    "tsv2_lymph_node": "Tabula Sapiens 2.0 - Lymph node",
    "tsv2_mammary": "Tabula Sapiens 2.0 - Mammary",
    "tsv2_muscle": "Tabula Sapiens 2.0 - Muscle",
    "tsv2_ovary": "Tabula Sapiens 2.0 - Ovary",
    "tsv2_prostate": "Tabula Sapiens 2.0 - Prostate",
    "tsv2_salivary_gland": "Tabula Sapiens 2.0 - Salivary gland",
    "tsv2_skin": "Tabula Sapiens 2.0 - Skin",
    "tsv2_small_intestine": "Tabula Sapiens 2.0 - Small intestine",
    "tsv2_spleen": "Tabula Sapiens 2.0 - Spleen",
    "tsv2_stomach": "Tabula Sapiens 2.0 - Stomach",
    "tsv2_testis": "Tabula Sapiens 2.0 - Testis",
    "tsv2_thymus": "Tabula Sapiens 2.0 - Thymus",
    "tsv2_tongue": "Tabula Sapiens 2.0 - Tongue",
    "tsv2_trachea": "Tabula Sapiens 2.0 - Trachea",
    "tsv2_uterus": "Tabula Sapiens 2.0 - Uterus",
    "tsv2_vasculature": "Tabula Sapiens 2.0 - Vasculature",
}


def dataset_to_display_name(dataset_name: str) -> str:
    """Try to map a dataset name to a more uniform, human-readable display string."""
    try:
        return _DATASET_TO_DISPLAY_NAME[dataset_name]
    except KeyError:
        # e.g. "my_awesome_dataset" -> "My awesome dataset"
        parts = dataset_name.split("_")
        return " ".join((parts[0].title(), *(part.lower() for part in parts[1:])))