import argparse
import itertools
import json
import logging
import os
import sys
import yaml
from collections import defaultdict
from collections.abc import Mapping
from copy import deepcopy
from datetime import datetime, timezone
from pathlib import Path
from pydantic import BaseModel, computed_field
from secrets import token_hex
from typing import Any, Generic, TypeVar
from czbenchmarks import runner
from czbenchmarks.cli import cli
from czbenchmarks.constants import PROCESSED_DATASETS_CACHE_PATH
from czbenchmarks.datasets import utils as dataset_utils
from czbenchmarks.datasets.base import BaseDataset
from czbenchmarks import exceptions
from czbenchmarks.metrics.types import MetricResult
from czbenchmarks.models import utils as model_utils
from czbenchmarks.models.types import ModelType
from czbenchmarks.tasks import utils as task_utils
from czbenchmarks.tasks.base import BaseTask
from czbenchmarks.tasks.clustering import ClusteringTask
from czbenchmarks.tasks.embedding import EmbeddingTask
from czbenchmarks.tasks.integration import BatchIntegrationTask
from czbenchmarks.tasks.label_prediction import MetadataLabelPredictionTask
from czbenchmarks.tasks.single_cell.cross_species import CrossSpeciesIntegrationTask
from czbenchmarks.tasks.single_cell.perturbation import PerturbationTask
from czbenchmarks import utils
log = logging.getLogger(__name__)
VALID_OUTPUT_FORMATS = ["json", "yaml"]
DEFAULT_OUTPUT_FORMAT = "json"
TaskType = TypeVar("TaskType", bound=BaseTask)
ModelArgsDict = dict[str, str | int] # Arguments passed to model inference
RuntimeMetricsDict = dict[
str, str | int | float
] # runtime metrics like elapsed time or CPU count, not implemented yet
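# Illustrative shapes of these aliases (values are hypothetical):
#   ModelArgsDict:      {"model_variant": "v1", "chunk_size": 512}
#   RuntimeMetricsDict: {"elapsed_seconds": 12.3, "cpu_count": 8}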
class ModelArgs(BaseModel):
name: str # Upper-case model name e.g. SCVI
args: dict[str, list[str | int]] # Args forwarded to the model container
class TaskArgs(BaseModel, Generic[TaskType]):
model_config = {"arbitrary_types_allowed": True} # Required to support TaskType
name: str # Lower-case task name e.g. embedding
task: TaskType
set_baseline: bool
baseline_args: dict[str, Any]
class TaskResult(BaseModel):
task_name: str
task_name_display: str
model_type: ModelType
dataset_names: list[str]
dataset_names_display: list[str]
model_args: ModelArgsDict
metrics: list[MetricResult]
runtime_metrics: RuntimeMetricsDict = {} # not implementing any of these for now
@computed_field
@property
def model_name_display(self) -> str:
return model_utils.model_to_display_name(self.model_type, self.model_args)
class CacheOptions(BaseModel):
download_embeddings: bool
upload_embeddings: bool
upload_results: bool
remote_cache_url: str
@classmethod
def from_args(cls, args: argparse.Namespace) -> "CacheOptions":
remote_cache_url = args.remote_cache_url or ""
return cls(
remote_cache_url=remote_cache_url,
download_embeddings=bool(remote_cache_url)
and args.remote_cache_download_embeddings,
upload_embeddings=bool(remote_cache_url)
and args.remote_cache_upload_embeddings,
upload_results=bool(remote_cache_url) and args.remote_cache_upload_results,
)
def add_arguments(parser: argparse.ArgumentParser) -> None:
"""
Add run command arguments to the parser.
"""
parser.add_argument(
"--models",
"-m",
nargs="+",
choices=model_utils.list_available_models(),
help="One or more model names (from models.yaml).",
)
parser.add_argument(
"--datasets",
"-d",
nargs="+",
choices=dataset_utils.list_available_datasets(),
help="One or more dataset names (from datasets.yaml).",
)
parser.add_argument(
"--tasks",
"-t",
nargs="+",
choices=task_utils.TASK_NAMES,
help="One or more tasks to run.",
)
parser.add_argument(
"--output-format",
"-fmt",
choices=VALID_OUTPUT_FORMATS,
default="yaml",
help="Output format for results (ignored if --output-file specifies a valid file extension)",
)
parser.add_argument(
"--output-file",
"-o",
help="Path to file or directory to save results (default is stdout)",
)
parser.add_argument(
"--remote-cache-url",
help=(
"AWS S3 URL prefix for caching embeddings and storing output "
"(example: s3://cz-benchmarks-example/). Files will be stored "
"underneath the current --version number. This alone will not "
"trigger any caching behavior unless one or more of the "
"--remote-cache-download-embeddings, --remote-cache-upload-embeddings "
"or --remote-cache-upload-results flags are specified."
),
)
parser.add_argument(
"--remote-cache-download-embeddings",
action="store_true",
help=(
"If specified, download embeddings from the remote cache to "
"PROCESSED_DATASETS_CACHE_PATH if local versions do not exist "
"or are older than those in the remote cache. Only embeddings "
"matching the current version will be downloaded."
),
default=False,
)
parser.add_argument(
"--remote-cache-upload-embeddings",
action="store_true",
help=(
"Upload any processed embeddings produced to the remote cache, overwriting "
"any that may already exist there for the current version. They will be "
"stored under s3://<remote_cache_url>/<version>/processed-datasets/*.dill"
),
default=False,
)
parser.add_argument(
"--remote-cache-upload-results",
action="store_true",
help=(
"Upload the results to the remote cache. This allows results "
"to be shared across instances. They will be stored under "
"s3://<remote_cache_url>/<version>/results/<timestamp>-<random_hex>.json"
),
default=False,
)
# Extra arguments for geneformer model
parser.add_argument(
"--geneformer-model-variant",
nargs="+",
help="Variant of the geneformer model to use (see docker/geneformer/config.yaml)",
)
# Extra arguments for scgenept model
parser.add_argument(
"--scgenept-model-variant",
nargs="+",
help="Variant of the scgenept model to use (see docker/scgenept/config.yaml)",
)
parser.add_argument(
"--scgenept-gene-pert",
nargs="+",
help="Gene perturbation to use for scgenept model",
)
parser.add_argument(
"--scgenept-dataset-name",
nargs="+",
help="Dataset name to use for scgenept model",
)
parser.add_argument(
"--scgenept-chunk-size",
type=int,
nargs="+",
help="Chunk size to use for scgenept model",
)
# Extra arguments for scgpt model
parser.add_argument(
"--scgpt-model-variant",
nargs="+",
help="Variant of the scgpt model to use (see docker/scgpt/config.yaml)",
)
# Extra arguments for scvi model
parser.add_argument(
"--scvi-model-variant",
nargs="+",
help="Variant of the scvi model to use (see docker/scvi/config.yaml)",
)
# Extra arguments for uce model
parser.add_argument(
"--uce-model-variant",
nargs="+",
help="Variant of the uce model to use (see docker/uce/config.yaml)",
)
# Extra arguments for transcriptformer model
parser.add_argument(
"--transcriptformer-model-variant",
nargs="+",
choices=["tf-sapiens", "tf-exemplar", "tf-metazoa"],
help="Variant of the transcriptformer model to use (tf-sapiens, tf-exemplar, tf-metazoa)",
)
parser.add_argument(
"--transcriptformer-batch-size",
type=int,
nargs="+",
help="Batch size for transcriptformer model inference",
)
# Extra arguments for AIDO model
parser.add_argument(
"--aido-model-variant",
nargs="*",
choices=["aido_cell_3m", "aido_cell_10m", "aido_cell_100m"],
default="aido_cell_3m",
help="Variant of the aido model to use. Default is aido_cell_3m",
)
parser.add_argument(
"--aido-batch-size",
type=int,
nargs="*",
help="Batch size for AIDO model inference (optional)",
)
# Extra arguments for clustering task
parser.add_argument(
"--clustering-task-label-key",
help="Label key to use for clustering task",
)
parser.add_argument(
"--clustering-task-set-baseline",
action="store_true",
help="Preprocess dataset and set PCA embedding as the BASELINE model output in the dataset",
)
# Extra arguments for embedding task
parser.add_argument(
"--embedding-task-label-key",
help="Label key to use for embedding task",
)
parser.add_argument(
"--embedding-task-set-baseline",
action="store_true",
help="Preprocess dataset and set PCA embedding as the BASELINE model output in the dataset",
)
# Extra arguments for label prediction task
parser.add_argument(
"--label-prediction-task-label-key",
help="Label key to use for label prediction task",
)
parser.add_argument(
"--label-prediction-task-set-baseline",
action="store_true",
help="Preprocess dataset and set PCA embedding as the BASELINE model output in the dataset",
)
parser.add_argument(
"--label-prediction-task-n-folds",
type=int,
help="Number of cross-validation folds (optional)",
)
parser.add_argument(
"--label-prediction-task-seed",
type=int,
help="Random seed for reproducibility (optional)",
)
parser.add_argument(
"--label-prediction-task-min-class-size",
type=int,
help="Minimum samples required per class (optional)",
)
# Extra arguments for integration task
parser.add_argument(
"--integration-task-label-key",
help="Label key to use for integration task",
)
parser.add_argument(
"--integration-task-set-baseline",
action="store_true",
help="Use raw gene expression matrix as features for classification (instead of embeddings)",
)
parser.add_argument(
"--integration-task-batch-key",
help="Key to access batch labels in metadata",
)
# Extra arguments for cross species integration task
parser.add_argument(
"--cross-species-task-label-key",
help="Label key to use for cross species integration task",
)
# Extra arguments for perturbation task
parser.add_argument(
"--perturbation-task-set-baseline",
action="store_true",
help="Use mean and median predictions as the BASELINE model output in the dataset",
)
parser.add_argument(
"--perturbation-task-baseline-gene-pert",
type=str,
help="Gene perturbation to use for baseline",
)
# Advanced feature: define multiple batches of jobs using JSON
parser.add_argument(
"--batch-json",
"-b",
nargs="+",
default=[""],
help='Override CLI arguments from the given JSON, e.g. \'{"output_file": "..."}\'. Can be set multiple times to run complex "batch" jobs.',
)
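# Illustrative invocation of the arguments defined above (the model and dataset
# names here are hypothetical; valid choices come from models.yaml and datasets.yaml):
#   czbenchmarks run -m SCVI -d my_dataset -t clustering embedding -o results/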
def main(parsed_args: argparse.Namespace) -> None:
"""
Execute a series of tasks using multiple models on a collection of datasets.
This function handles the benchmarking process by iterating over the specified datasets,
running inference with the provided models to generate results, and running the tasks to evaluate
the generated outputs.
"""
task_results: list[TaskResult] = []
batch_args = parse_batch_json(parsed_args.batch_json)
cache_options = CacheOptions.from_args(parsed_args)
for batch_idx, batch_dict in enumerate(batch_args):
log.info(f"Starting batch {batch_idx + 1}/{len(parsed_args.batch_json)}")
args = deepcopy(parsed_args)
for batch_key, batch_val in batch_dict.items():
setattr(args, batch_key, batch_val)
# Collect all the arguments that we'll need to pass directly to each model
model_args: list[ModelArgs] = []
for model_name in args.models or []:
model_args.append(parse_model_args(model_name.lower(), args))
# Collect all the task-related arguments
task_args: list[TaskArgs] = []
requested_tasks = set(args.tasks or [])
if "clustering" in requested_tasks:
task_args.append(parse_task_args("clustering", ClusteringTask, args))
if "embedding" in requested_tasks:
task_args.append(parse_task_args("embedding", EmbeddingTask, args))
if "label_prediction" in requested_tasks:
task_args.append(
parse_task_args("label_prediction", MetadataLabelPredictionTask, args)
)
if "integration" in requested_tasks:
task_args.append(parse_task_args("integration", BatchIntegrationTask, args))
if "perturbation" in requested_tasks:
task_args.append(parse_task_args("perturbation", PerturbationTask, args))
if "cross_species" in requested_tasks:
task_args.append(
parse_task_args("cross_species", CrossSpeciesIntegrationTask, args)
)
# Run the tasks
task_result = run(
dataset_names=args.datasets,
model_args=model_args,
task_args=task_args,
cache_options=cache_options,
)
task_results.extend(task_result)
# Write the results to the specified output
write_results(
task_results,
cache_options=cache_options,
output_format=args.output_format,
output_file=args.output_file,
)
def run(
dataset_names: list[str],
model_args: list[ModelArgs],
task_args: list[TaskArgs],
cache_options: CacheOptions,
) -> list[TaskResult]:
"""
Run a set of tasks against a set of datasets. Runs inference if any `model_args` are specified.
"""
log.info(
f"Starting benchmarking batch for {len(dataset_names)} datasets, {len(model_args)} models, and {len(task_args)} tasks"
)
if not model_args:
return run_without_inference(dataset_names, task_args)
return run_with_inference(
dataset_names, model_args, task_args, cache_options=cache_options
)
def run_with_inference(
dataset_names: list[str],
model_args: list[ModelArgs],
task_args: list[TaskArgs],
cache_options: CacheOptions,
) -> list[TaskResult]:
"""
Execute a series of tasks using multiple models on a collection of datasets.
This function handles the benchmarking process by iterating over the specified datasets,
running inference with the provided models to generate results, and running the tasks to evaluate
the generated outputs.
"""
task_results: list[TaskResult] = []
single_dataset_task_names = set(task_utils.TASK_NAMES) - set(
task_utils.MULTI_DATASET_TASK_NAMES
)
single_dataset_tasks: list[TaskArgs] = [
t for t in task_args if t.name in single_dataset_task_names
]
multi_dataset_tasks: list[TaskArgs] = [
t for t in task_args if t.name in task_utils.MULTI_DATASET_TASK_NAMES
]
embeddings_for_multi_dataset_tasks: dict[str, BaseDataset] = {}
# Get all unique combinations of model arguments: each requires a separate inference run
model_arg_permutations = get_model_arg_permutations(model_args)
if multi_dataset_tasks and not all(
len(ma) < 2 for ma in model_arg_permutations.values()
):
raise ValueError(
"Having multiple model_args for multi-dataset tasks is not supported"
)
for dataset_idx, dataset_name in enumerate(dataset_names):
log.info(
f'Processing dataset "{dataset_name}" ({dataset_idx + 1}/{len(dataset_names)})'
)
for model_name, model_arg_permutation in model_arg_permutations.items():
for args_idx, args in enumerate(model_arg_permutation):
log.info(
f'Starting model inference "{model_name}" ({args_idx + 1}/{len(model_arg_permutation)}) '
f'for dataset "{dataset_name}" ({args})'
)
processed_dataset = run_inference_or_load_from_cache(
dataset_name,
model_name=model_name,
model_args=args,
cache_options=cache_options,
)
# NOTE: accumulating datasets with attached embeddings in memory
# can be memory intensive
if multi_dataset_tasks:
embeddings_for_multi_dataset_tasks[dataset_name] = processed_dataset
# Run each single-dataset task against the processed dataset
for task_arg_idx, task_arg in enumerate(single_dataset_tasks):
log.info(
f'Starting task "{task_arg.name}" ({task_arg_idx + 1}/{len(task_args)}) for '
f'dataset "{dataset_name}" and model "{model_name}" ({task_arg})'
)
task_result = run_task(
dataset_name, processed_dataset, {model_name: args}, task_arg
)
task_results.extend(task_result)
# Run multi-dataset tasks
embeddings: list[BaseDataset] = list(embeddings_for_multi_dataset_tasks.values())
for task_arg_idx, task_arg in enumerate(multi_dataset_tasks):
log.info(
f'Starting multi-dataset task "{task_arg.name}" ({task_arg_idx + 1}/{len(multi_dataset_tasks)}) for datasets "{dataset_names}"'
)
model_args_for_run = {
model_name: permutation[0]
for model_name, permutation in model_arg_permutations.items()
if len(permutation) == 1
}
task_result = run_multi_dataset_task(
dataset_names, embeddings, model_args_for_run, task_arg
)
task_results.extend(task_result)
return task_results
def run_inference_or_load_from_cache(
dataset_name: str,
*,
model_name: str,
model_args: ModelArgsDict,
cache_options: CacheOptions,
) -> BaseDataset:
"""
Load the processed dataset from the cache if it exists, else run inference and save to cache.
"""
processed_dataset = try_processed_datasets_cache(
dataset_name,
model_name=model_name,
model_args=model_args,
cache_options=cache_options,
)
if processed_dataset:
log.info("Processed dataset is cached: skipping inference")
return processed_dataset
dataset = dataset_utils.load_dataset(dataset_name)
processed_dataset = runner.run_inference(
model_name,
dataset,
gpu=True,
**model_args, # type: ignore [arg-type]
)
# if we ran inference, put the embeddings produced into the cache (local and possibly remote)
set_processed_datasets_cache(
processed_dataset,
dataset_name,
model_name=model_name,
model_args=model_args,
cache_options=cache_options,
)
return processed_dataset
def run_without_inference(
dataset_names: list[str], task_args: list[TaskArgs]
) -> list[TaskResult]:
"""
Run a set of tasks directly against raw datasets without first running model inference.
"""
task_results: list[TaskResult] = []
single_dataset_task_names = set(task_utils.TASK_NAMES) - set(
task_utils.MULTI_DATASET_TASK_NAMES
)
single_dataset_tasks: list[TaskArgs] = [
t for t in task_args if t.name in single_dataset_task_names
]
multi_dataset_tasks: list[TaskArgs] = [
t for t in task_args if t.name in task_utils.MULTI_DATASET_TASK_NAMES
]
embeddings_for_multi_dataset_tasks: dict[str, BaseDataset] = {}
for dataset_idx, dataset_name in enumerate(dataset_names):
log.info(
f'Processing dataset "{dataset_name}" ({dataset_idx + 1}/{len(dataset_names)})'
)
dataset = dataset_utils.load_dataset(dataset_name)
# NOTE: accumulating datasets with attached embeddings in memory
# can be memory intensive
if multi_dataset_tasks:
embeddings_for_multi_dataset_tasks[dataset_name] = dataset
for task_arg_idx, task_arg in enumerate(single_dataset_tasks):
log.info(
f'Starting task "{task_arg.name}" ({task_arg_idx + 1}/{len(task_args)}) for dataset "{dataset_name}"'
)
task_result = run_task(dataset_name, dataset, {}, task_arg)
task_results.extend(task_result)
# Run multi-dataset tasks
embeddings: list[BaseDataset] = list(embeddings_for_multi_dataset_tasks.values())
for task_arg_idx, task_arg in enumerate(multi_dataset_tasks):
log.info(
f'Starting multi-dataset task "{task_arg.name}" ({task_arg_idx + 1}/{len(multi_dataset_tasks)}) for datasets "{dataset_names}"'
)
task_result = run_multi_dataset_task(dataset_names, embeddings, {}, task_arg)
task_results.extend(task_result)
return task_results
def run_multi_dataset_task(
dataset_names: list[str],
embeddings: list[BaseDataset],
model_args: dict[str, ModelArgsDict],
task_args: TaskArgs,
) -> list[TaskResult]:
"""
Run a multi-dataset task against all of the given datasets at once and return the results.
"""
task_results: list[TaskResult] = []
if task_args.set_baseline:
raise ValueError("Baseline embedding run not allowed for multi-dataset tasks")
result: dict[ModelType, list[MetricResult]] = task_args.task.run(embeddings)
if not isinstance(result, Mapping):
raise TypeError("Expect a Map ADT for a task result")
# sorting the dataset_names for the purposes of using it as a
# cache key and uniform presentation to the user
dataset_names.sort()
for model_type, metrics in result.items():
task_result = TaskResult(
task_name=task_args.name,
task_name_display=task_args.task.display_name,
model_type=model_type.value,
dataset_names=dataset_names,
dataset_names_display=[
dataset_utils.dataset_to_display_name(ds) for ds in dataset_names
],
model_args=model_args.get(model_type.value) or {},
metrics=metrics,
)
task_results.append(task_result)
log.info(task_result)
return task_results
def run_task(
dataset_name: str,
dataset: BaseDataset,
model_args: dict[str, ModelArgsDict],
task_args: TaskArgs,
) -> list[TaskResult]:
"""
Run a task and return the results.
"""
task_results: list[TaskResult] = []
if task_args.set_baseline:
dataset.load_data()
task_args.task.set_baseline(dataset, **task_args.baseline_args)
result: dict[ModelType, list[MetricResult]] = task_args.task.run(dataset)
if isinstance(result, list):
raise TypeError("Expected a single task result, got list")
for model_type, metrics in result.items():
if model_type == ModelType.BASELINE:
model_args_to_store = task_args.baseline_args
else:
model_args_to_store = model_args.get(model_type.value) or {}
task_result = TaskResult(
task_name=task_args.name,
task_name_display=task_args.task.display_name,
model_type=model_type.value,
dataset_names=[dataset_name],
dataset_names_display=[dataset_utils.dataset_to_display_name(dataset_name)],
model_args=model_args_to_store,
metrics=metrics,
)
task_results.append(task_result)
log.info(task_result)
return task_results
def get_model_arg_permutations(
model_args: list[ModelArgs],
) -> dict[str, list[ModelArgsDict]]:
"""
Generate all the "permutations" of model arguments we want to run for each dataset:
E.g. Running 2 variants of scgenept at 2 chunk sizes results in 4 permutations
"""
result: dict[str, list[ModelArgsDict]] = defaultdict(list)
for model_arg in model_args:
if not model_arg.args:
result[model_arg.name] = [{}]
continue
keys, values = zip(*model_arg.args.items())
permutations: list[dict[str, str | int]] = [
{k: v for k, v in zip(keys, permutation)}
for permutation in itertools.product(*values)
]
result[model_arg.name] = permutations
return result
def write_results(
task_results: list[TaskResult],
*,
cache_options: CacheOptions,
output_format: str = DEFAULT_OUTPUT_FORMAT,
output_file: str | None = None, # Writes to stdout if None
) -> None:
"""
Format and write results to the given directory or file.
"""
results_dict = {
"czbenchmarks_version": cli.get_version(),
"args": "czbenchmarks " + " ".join(sys.argv[1:]),
"task_results": [result.model_dump(mode="json") for result in task_results],
}
# Get the intended format/extension
if output_file and output_file.endswith(".json"):
output_format = "json"
elif output_file and (
output_file.endswith(".yaml") or output_file.endswith(".yml")
):
output_format = "yaml"
elif output_format not in VALID_OUTPUT_FORMATS:
raise ValueError(f"Invalid output format: {output_format}")
results_str = ""
if output_format == "json":
results_str = json.dumps(results_dict, indent=2)
else:
results_str = yaml.dump(results_dict)
if cache_options.remote_cache_url and cache_options.upload_results:
remote_url = get_result_url_for_remote(cache_options.remote_cache_url)
try:
utils.upload_blob_to_remote(
results_str.encode("utf-8"), remote_url, overwrite_existing=False
)
log.info("Uploaded results to %r", remote_url)
except exceptions.RemoteStorageError:
log.exception(f"Failed to upload results to {remote_url!r}")
# Generate a unique filename if we were passed a directory
if output_file and (os.path.isdir(output_file) or output_file.endswith("/")):
current_time = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(
output_file, f"czbenchmarks_results_{current_time}.{output_format}"
)
if output_file:
with open(output_file, "w") as f:
f.write(results_str)
f.write("\n")
log.info("Wrote results to %r", output_file)
else:
# Write to stdout if not otherwise specified
sys.stdout.write(results_str)
sys.stdout.write("\n")
def get_result_url_for_remote(remote_prefix_url: str) -> str:
nonce = token_hex(4)
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
version = cli.get_version()
return f"{remote_prefix_url.rstrip('/')}/{version}/results/{timestamp}-{nonce}.json"
def set_processed_datasets_cache(
dataset: BaseDataset,
dataset_name: str,
*,
model_name: str,
model_args: ModelArgsDict,
cache_options: CacheOptions,
) -> None:
"""
Write a dataset to the cache
A "processed" dataset has been run with model inference for the given arguments.
"""
dataset_filename = get_processed_dataset_cache_filename(
dataset_name, model_name=model_name, model_args=model_args
)
cache_dir = Path(PROCESSED_DATASETS_CACHE_PATH).expanduser().absolute()
cache_file = cache_dir / dataset_filename
try:
# "Unload" the source data so we only cache the results
dataset.unload_data()
cache_dir.mkdir(parents=True, exist_ok=True)
dataset.serialize(str(cache_file))
succeeded = True
except Exception as e:
# Log the exception, but don't raise if we can't write to the cache for some reason
log.exception(
f'Failed to serialize processed dataset to cache "{cache_file}": {e}'
)
succeeded = False
if succeeded and cache_options.upload_embeddings:
# upload the new embeddings, overwriting any that may already exist
remote_prefix = get_remote_cache_prefix(cache_options)
try:
utils.upload_file_to_remote(
cache_file, remote_prefix, overwrite_existing=True
)
log.info(f"Uploaded processed dataset from {cache_file} to {remote_prefix}")
except exceptions.RemoteStorageError:
log.exception("Unable to upload processed dataset to remote cache")
dataset.load_data()
def try_processed_datasets_cache(
dataset_name: str,
*,
model_name: str,
model_args: ModelArgsDict,
cache_options: CacheOptions,
) -> BaseDataset | None:
"""
Deserialize and return a processed dataset from the cache if it exists, else return None.
"""
dataset_filename = get_processed_dataset_cache_filename(
dataset_name, model_name=model_name, model_args=model_args
)
cache_dir = Path(PROCESSED_DATASETS_CACHE_PATH).expanduser().absolute()
cache_file = cache_dir / dataset_filename
if cache_options.download_embeddings:
# check the remote cache and download the file if a local version doesn't
# exist, or if the remote version is newer than the local version
remote_url = f"{get_remote_cache_prefix(cache_options)}{dataset_filename}"
local_modified: datetime | None = None
remote_modified: datetime | None = None
if cache_file.exists():
local_modified = datetime.fromtimestamp(
cache_file.stat().st_mtime, tz=timezone.utc
)
try:
remote_modified = utils.get_remote_last_modified(
remote_url, make_unsigned_request=True
)
except exceptions.RemoteStorageError:
# not a great way to handle this, but maybe the cache bucket is not public
try:
log.warning(
"Unsigned request to remote storage cache failed. Trying signed request."
)
remote_modified = utils.get_remote_last_modified(
remote_url, make_unsigned_request=False
)
except exceptions.RemoteStorageError:
pass
if remote_modified is None:
log.info("Remote cached embeddings don't exist. Skipping download.")
elif local_modified is not None and (remote_modified <= local_modified):
log.info(
f"Remote cached embeddings modified at {remote_modified}. "
f"Local cache files modified more recently at {local_modified}. "
"Skipping download."
)
else:
try:
utils.download_file_from_remote(remote_url, cache_dir, make_unsigned_request=True)
log.info(
f"Downloaded cached embeddings from {remote_url} to {cache_dir}"
)
except exceptions.RemoteStorageError:
# not a great way to handle this, but maybe the cache bucket is not public
try:
log.warning(
"Unsigned request to remote storage cache failed. Trying signed request."
)
utils.download_file_from_remote(
remote_url, cache_dir, make_unsigned_request=False
)
log.info(
f"Downloaded cached embeddings from {remote_url} to {cache_dir}"
)
except exceptions.RemoteStorageError:
log.warning(
f"Unable to retrieve embeddings from remote cache at {remote_url!r}"
)
if cache_file.exists():
# Load the original dataset
dataset = dataset_utils.load_dataset(dataset_name)
dataset.load_data()
# Attach the cached results to the dataset
processed_dataset = BaseDataset.deserialize(str(cache_file))
dataset._outputs = processed_dataset._outputs
return dataset
return None
def get_remote_cache_prefix(cache_options: CacheOptions):
"""get the prefix ending in '/' that the remote processed datasets go under"""
return f"{cache_options.remote_cache_url.rstrip('/')}/{cli.get_version()}/processed-datasets/"
def get_processed_dataset_cache_filename(
dataset_name: str, *, model_name: str, model_args: ModelArgsDict
) -> str:
"""
Generate a unique filename for the given dataset and model arguments.
"""
if model_args:
model_args_str = f"{model_name}_" + "_".join(
f"{k}-{v}" for k, v in sorted(model_args.items())
)
else:
model_args_str = model_name
filename = f"{dataset_name}_{model_args_str}.dill"
return filename
def get_processed_dataset_cache_path(
dataset_name: str, *, model_name: str, model_args: ModelArgsDict
) -> Path:
"""
Return a unique file path in the cache directory for the given dataset and model arguments.
"""
cache_dir = Path(PROCESSED_DATASETS_CACHE_PATH).expanduser().absolute()
filename = get_processed_dataset_cache_filename(
dataset_name, model_name=model_name, model_args=model_args
)
return cache_dir / filename
def parse_model_args(model_name: str, args: argparse.Namespace) -> ModelArgs:
"""
Populate a ModelArgs instance from the given argparse namespace.
"""
prefix = f"{model_name.lower()}_"
model_args: dict[str, Any] = {}
for k, v in vars(args).items():
if v is not None and k.startswith(prefix):
model_args[k.removeprefix(prefix)] = v
return ModelArgs(name=model_name.upper(), args=model_args)
def parse_task_args(
task_name: str, TaskCls: type[TaskType], args: argparse.Namespace
) -> TaskArgs:
"""
Populate a TaskArgs instance from the given argparse namespace.
"""
prefix = f"{task_name.lower()}_task_"
task_args: dict[str, Any] = {}
baseline_args: dict[str, Any] = {}
for k, v in vars(args).items():
if v is not None and k.startswith(prefix):
trimmed_k = k.removeprefix(prefix)
if trimmed_k.startswith("baseline_"):
baseline_args[trimmed_k.removeprefix("baseline_")] = v
else:
task_args[trimmed_k] = v
set_baseline = task_args.pop("set_baseline", False)
return TaskArgs(
name=task_name,
task=TaskCls(**task_args),
set_baseline=set_baseline,
baseline_args=baseline_args,
)
def parse_batch_json(batch_json_list: list[str]) -> list[dict[str, Any]]:
"""
Parse the `--batch-json` argument.
Returns a list of dicts where each entry is a batch of CLI arguments.
"""
batches: list[dict[str, Any]] = []
if not batch_json_list:
return [{}]
for batch_json in batch_json_list:
if not batch_json.strip():
batches.append({})
continue
# Load JSON from disk if we were given a valid file path
if os.path.isfile(batch_json):
try:
with open(batch_json, "r") as f:
batches.append(json.load(f))
except Exception as e:
raise ValueError(
f"Failed to load batch JSON from file {batch_json}: {e}"
) from e
continue
# Otherwise treat the input as JSON
try:
result = json.loads(batch_json)
if isinstance(result, list):
batches.extend(result)
elif isinstance(result, dict):
batches.append(result)
else:
raise ValueError(
"Invalid batch JSON: input must be a dictionary of CLI arguments"
)
continue
except json.JSONDecodeError as e:
raise ValueError(f"Invalid batch JSON {batch_json}: {e}") from e
return batches