Source code for czbenchmarks.cli.cli_run

import argparse
import itertools
import json
import logging
import os
import sys
import yaml

from collections import defaultdict
from collections.abc import Mapping
from copy import deepcopy
from datetime import datetime, timezone
from pathlib import Path
from pydantic import BaseModel, computed_field
from secrets import token_hex
from typing import Any, Generic, TypeVar

from czbenchmarks import runner
from czbenchmarks.cli import cli
from czbenchmarks.constants import PROCESSED_DATASETS_CACHE_PATH
from czbenchmarks.datasets import utils as dataset_utils
from czbenchmarks.datasets.base import BaseDataset
from czbenchmarks import exceptions
from czbenchmarks.metrics.types import MetricResult
from czbenchmarks.models import utils as model_utils
from czbenchmarks.models.types import ModelType
from czbenchmarks.tasks import utils as task_utils
from czbenchmarks.tasks.base import BaseTask
from czbenchmarks.tasks.clustering import ClusteringTask
from czbenchmarks.tasks.embedding import EmbeddingTask
from czbenchmarks.tasks.integration import BatchIntegrationTask
from czbenchmarks.tasks.label_prediction import MetadataLabelPredictionTask
from czbenchmarks.tasks.single_cell.cross_species import CrossSpeciesIntegrationTask
from czbenchmarks.tasks.single_cell.perturbation import PerturbationTask
from czbenchmarks import utils


log = logging.getLogger(__name__)

VALID_OUTPUT_FORMATS = ["json", "yaml"]
DEFAULT_OUTPUT_FORMAT = "json"

TaskType = TypeVar("TaskType", bound=BaseTask)
ModelArgsDict = dict[str, str | int]  # Arguments passed to model inference
RuntimeMetricsDict = dict[
    str, str | int | float
]  # runtime metrics like elapsed time or CPU count, not implemented yet


class ModelArgs(BaseModel):
    name: str  # Upper-case model name e.g. SCVI
    args: dict[str, list[str | int]]  # Args forwarded to the model container

class TaskArgs(BaseModel, Generic[TaskType]):
    model_config = {"arbitrary_types_allowed": True}  # Required to support TaskType

    name: str  # Lower-case task name e.g. embedding
    task: TaskType
    set_baseline: bool
    baseline_args: dict[str, Any]

class TaskResult(BaseModel):
    task_name: str
    task_name_display: str
    model_type: ModelType
    dataset_names: list[str]
    dataset_names_display: list[str]
    model_args: ModelArgsDict
    metrics: list[MetricResult]
    runtime_metrics: RuntimeMetricsDict = {}  # not implementing any of these for now

    @computed_field
    @property
    def model_name_display(self) -> str:
        return model_utils.model_to_display_name(self.model_type, self.model_args)

class CacheOptions(BaseModel):
    download_embeddings: bool
    upload_embeddings: bool
    upload_results: bool
    remote_cache_url: str

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "CacheOptions":
        remote_cache_url = args.remote_cache_url or ""
        return cls(
            remote_cache_url=remote_cache_url,
            download_embeddings=bool(remote_cache_url)
            and args.remote_cache_download_embeddings,
            upload_embeddings=bool(remote_cache_url)
            and args.remote_cache_upload_embeddings,
            upload_results=bool(remote_cache_url) and args.remote_cache_upload_results,
        )

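# A minimal sketch of how CacheOptions.from_args resolves the CLI flags (the S3 URL
# below is a placeholder): with --remote-cache-url s3://my-bucket/ and
# --remote-cache-download-embeddings set, from_args() returns
#   CacheOptions(download_embeddings=True, upload_embeddings=False,
#                upload_results=False, remote_cache_url="s3://my-bucket/")
# With no --remote-cache-url, all three booleans resolve to False regardless of the
# other flags.
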
def add_arguments(parser: argparse.ArgumentParser) -> None:
    """
    Add run command arguments to the parser.
    """
    parser.add_argument(
        "--models",
        "-m",
        nargs="+",
        choices=model_utils.list_available_models(),
        help="One or more model names (from models.yaml).",
    )
    parser.add_argument(
        "--datasets",
        "-d",
        nargs="+",
        choices=dataset_utils.list_available_datasets(),
        help="One or more dataset names (from datasets.yaml).",
    )
    parser.add_argument(
        "--tasks",
        "-t",
        nargs="+",
        choices=task_utils.TASK_NAMES,
        help="One or more tasks to run.",
    )
    parser.add_argument(
        "--output-format",
        "-fmt",
        choices=VALID_OUTPUT_FORMATS,
        default="yaml",
        help="Output format for results (ignored if --output-file specifies a valid file extension)",
    )
    parser.add_argument(
        "--output-file",
        "-o",
        help="Path to file or directory to save results (default is stdout)",
    )
    parser.add_argument(
        "--remote-cache-url",
        help=(
            "AWS S3 URL prefix for caching embeddings and storing output "
            "(example: s3://cz-benchmarks-example/). Files will be stored "
            "underneath the current --version number. This alone will not "
            "trigger any caching behavior unless one or more of the "
            "--remote-cache-download-embeddings, --remote-cache-upload-embeddings "
            "or --remote-cache-upload-results flags are specified."
        ),
    )
    parser.add_argument(
        "--remote-cache-download-embeddings",
        action="store_true",
        help=(
            "If specified, download embeddings from the remote cache to "
            "PROCESSED_DATASETS_CACHE_PATH if local versions do not exist "
            "or are older than those in the remote cache. Only embeddings "
            "matching the current version will be downloaded."
        ),
        default=False,
    )
    parser.add_argument(
        "--remote-cache-upload-embeddings",
        action="store_true",
        help=(
            "Upload any processed embeddings produced to the remote cache, overwriting "
            "any that may already exist there for the current version. They will be "
            "stored under s3://<remote_cache_url>/<version>/processed-datasets/*.dill"
        ),
        default=False,
    )
    parser.add_argument(
        "--remote-cache-upload-results",
        action="store_true",
        help=(
            "Upload the results to the remote cache. This allows results "
            "to be shared across instances. They will be stored under "
            "s3://<remote_cache_url>/<version>/results/<timestamp>-<random_hex>.json"
        ),
        default=False,
    )
    # Extra arguments for geneformer model
    parser.add_argument(
        "--geneformer-model-variant",
        nargs="+",
        help="Variant of the geneformer model to use (see docker/geneformer/config.yaml)",
    )
    # Extra arguments for scgenept model
    parser.add_argument(
        "--scgenept-model-variant",
        nargs="+",
        help="Variant of the scgenept model to use (see docker/scgenept/config.yaml)",
    )
    parser.add_argument(
        "--scgenept-gene-pert",
        nargs="+",
        help="Gene perturbation to use for scgenept model",
    )
    parser.add_argument(
        "--scgenept-dataset-name",
        nargs="+",
        help="Dataset name to use for scgenept model",
    )
    parser.add_argument(
        "--scgenept-chunk-size",
        type=int,
        nargs="+",
        help="Chunk size to use for scgenept model",
    )
    # Extra arguments for scgpt model
    parser.add_argument(
        "--scgpt-model-variant",
        nargs="+",
        help="Variant of the scgpt model to use (see docker/scgpt/config.yaml)",
    )
    # Extra arguments for scvi model
    parser.add_argument(
        "--scvi-model-variant",
        nargs="+",
        help="Variant of the scvi model to use (see docker/scvi/config.yaml)",
    )
    # Extra arguments for uce model
    parser.add_argument(
        "--uce-model-variant",
        nargs="+",
        help="Variant of the uce model to use (see docker/uce/config.yaml)",
    )
    # Extra arguments for transcriptformer model
    parser.add_argument(
        "--transcriptformer-model-variant",
        nargs="+",
        choices=["tf-sapiens", "tf-exemplar", "tf-metazoa"],
        help="Variant of the transcriptformer model to use (tf-sapiens, tf-exemplar, tf-metazoa)",
    )
    parser.add_argument(
        "--transcriptformer-batch-size",
        type=int,
        nargs="+",
        help="Batch size for transcriptformer model inference",
    )
    # Extra arguments for AIDO model
    parser.add_argument(
        "--aido-model-variant",
        nargs="*",
        choices=["aido_cell_3m", "aido_cell_10m", "aido_cell_100m"],
        default="aido_cell_3m",
        help="Variant of the aido model to use. Default is aido_cell_3m",
    )
    parser.add_argument(
        "--aido-batch-size",
        type=int,
        nargs="*",
        help="Batch size for AIDO model inference (optional)",
    )
    # Extra arguments for clustering task
    parser.add_argument(
        "--clustering-task-label-key",
        help="Label key to use for clustering task",
    )
    parser.add_argument(
        "--clustering-task-set-baseline",
        action="store_true",
        help="Preprocess dataset and set PCA embedding as the BASELINE model output in the dataset",
    )
    # Extra arguments for embedding task
    parser.add_argument(
        "--embedding-task-label-key",
        help="Label key to use for embedding task",
    )
    parser.add_argument(
        "--embedding-task-set-baseline",
        action="store_true",
        help="Preprocess dataset and set PCA embedding as the BASELINE model output in the dataset",
    )
    # Extra arguments for label prediction task
    parser.add_argument(
        "--label-prediction-task-label-key",
        help="Label key to use for label prediction task",
    )
    parser.add_argument(
        "--label-prediction-task-set-baseline",
        action="store_true",
        help="Preprocess dataset and set PCA embedding as the BASELINE model output in the dataset",
    )
    parser.add_argument(
        "--label-prediction-task-n-folds",
        type=int,
        help="Number of cross-validation folds (optional)",
    )
    parser.add_argument(
        "--label-prediction-task-seed",
        type=int,
        help="Random seed for reproducibility (optional)",
    )
    parser.add_argument(
        "--label-prediction-task-min-class-size",
        type=int,
        help="Minimum samples required per class (optional)",
    )
    # Extra arguments for integration task
    parser.add_argument(
        "--integration-task-label-key",
        help="Label key to use for integration task",
    )
    parser.add_argument(
        "--integration-task-set-baseline",
        action="store_true",
        help="Use raw gene expression matrix as features for classification (instead of embeddings)",
    )
    parser.add_argument(
        "--integration-task-batch-key",
        help="Key to access batch labels in metadata",
    )
    # Extra arguments for cross species integration task
    parser.add_argument(
        "--cross-species-task-label-key",
        help="Label key to use for cross species integration task",
    )
    # Extra arguments for perturbation task
    parser.add_argument(
        "--perturbation-task-set-baseline",
        action="store_true",
        help="Use mean and median predictions as the BASELINE model output in the dataset",
    )
    parser.add_argument(
        "--perturbation-task-baseline-gene-pert",
        type=str,
        help="Gene perturbation to use for baseline",
    )
    # Advanced feature: define multiple batches of jobs using JSON
    parser.add_argument(
        "--batch-json",
        "-b",
        nargs="+",
        default=[""],
        help='Override CLI arguments from the given JSON, e.g. \'{"output_file": "..."}\'. Can be set multiple times to run complex "batch" jobs.',
    )

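# A hypothetical end-to-end invocation combining these flags (model, dataset, and
# bucket names are placeholders, and the sub-command name is whatever
# czbenchmarks.cli registers for this module):
#
#   czbenchmarks run \
#       --models SCVI \
#       --datasets my_dataset \
#       --tasks clustering embedding \
#       --clustering-task-label-key cell_type \
#       --remote-cache-url s3://my-bucket/ \
#       --remote-cache-download-embeddings \
#       --output-file results/
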
def main(parsed_args: argparse.Namespace) -> None:
    """
    Execute a series of tasks using multiple models on a collection of datasets.

    This function handles the benchmarking process by iterating over the specified
    datasets, running inference with the provided models to generate results, and
    running the tasks to evaluate the generated outputs.
    """
    task_results: list[TaskResult] = []
    batch_args = parse_batch_json(parsed_args.batch_json)
    cache_options = CacheOptions.from_args(parsed_args)

    for batch_idx, batch_dict in enumerate(batch_args):
        log.info(f"Starting batch {batch_idx + 1}/{len(batch_args)}")
        args = deepcopy(parsed_args)
        for batch_key, batch_val in batch_dict.items():
            setattr(args, batch_key, batch_val)

        # Collect all the arguments that we'll need to pass directly to each model
        model_args: list[ModelArgs] = []
        for model_name in args.models or []:
            model_args.append(parse_model_args(model_name.lower(), args))

        # Collect all the task-related arguments
        task_args: list[TaskArgs] = []
        if "clustering" in args.tasks:
            task_args.append(parse_task_args("clustering", ClusteringTask, args))
        if "embedding" in args.tasks:
            task_args.append(parse_task_args("embedding", EmbeddingTask, args))
        if "label_prediction" in args.tasks:
            task_args.append(
                parse_task_args("label_prediction", MetadataLabelPredictionTask, args)
            )
        if "integration" in args.tasks:
            task_args.append(
                parse_task_args("integration", BatchIntegrationTask, args)
            )
        if "perturbation" in args.tasks:
            task_args.append(parse_task_args("perturbation", PerturbationTask, args))
        if "cross_species" in args.tasks:
            task_args.append(
                parse_task_args("cross_species", CrossSpeciesIntegrationTask, args)
            )

        # Run the tasks
        task_result = run(
            dataset_names=args.datasets,
            model_args=model_args,
            task_args=task_args,
            cache_options=cache_options,
        )
        task_results.extend(task_result)

    # Write the results to the specified output
    write_results(
        task_results,
        cache_options=cache_options,
        output_format=args.output_format,
        output_file=args.output_file,
    )

def run(
    dataset_names: list[str],
    model_args: list[ModelArgs],
    task_args: list[TaskArgs],
    cache_options: CacheOptions,
) -> list[TaskResult]:
    """
    Run a set of tasks against a set of datasets.

    Runs inference if any `model_args` are specified.
    """
    log.info(
        f"Starting benchmarking batch for {len(dataset_names)} datasets, "
        f"{len(model_args)} models, and {len(task_args)} tasks"
    )
    if not model_args:
        return run_without_inference(dataset_names, task_args)
    return run_with_inference(
        dataset_names, model_args, task_args, cache_options=cache_options
    )

def run_with_inference(
    dataset_names: list[str],
    model_args: list[ModelArgs],
    task_args: list[TaskArgs],
    cache_options: CacheOptions,
) -> list[TaskResult]:
    """
    Execute a series of tasks using multiple models on a collection of datasets.

    This function handles the benchmarking process by iterating over the specified
    datasets, running inference with the provided models to generate results, and
    running the tasks to evaluate the generated outputs.
    """
    task_results: list[TaskResult] = []

    single_dataset_task_names = set(task_utils.TASK_NAMES) - set(
        task_utils.MULTI_DATASET_TASK_NAMES
    )
    single_dataset_tasks: list[TaskArgs] = [
        t for t in task_args if t.name in single_dataset_task_names
    ]
    multi_dataset_tasks: list[TaskArgs] = [
        t for t in task_args if t.name in task_utils.MULTI_DATASET_TASK_NAMES
    ]
    embeddings_for_multi_dataset_tasks: dict[str, BaseDataset] = {}

    # Get all unique combinations of model arguments: each requires a separate inference run
    model_arg_permutations = get_model_arg_permutations(model_args)
    if multi_dataset_tasks and not all(
        len(ma) < 2 for ma in model_arg_permutations.values()
    ):
        raise ValueError(
            "Having multiple model_args for multi-dataset tasks is not supported"
        )

    for dataset_idx, dataset_name in enumerate(dataset_names):
        log.info(
            f'Processing dataset "{dataset_name}" ({dataset_idx + 1}/{len(dataset_names)})'
        )
        for model_name, model_arg_permutation in model_arg_permutations.items():
            for args_idx, args in enumerate(model_arg_permutation):
                log.info(
                    f'Starting model inference "{model_name}" ({args_idx + 1}/{len(model_arg_permutation)}) '
                    f'for dataset "{dataset_name}" ({args})'
                )
                processed_dataset = run_inference_or_load_from_cache(
                    dataset_name,
                    model_name=model_name,
                    model_args=args,
                    cache_options=cache_options,
                )
                # NOTE: accumulating datasets with attached embeddings in memory
                # can be memory intensive
                if multi_dataset_tasks:
                    embeddings_for_multi_dataset_tasks[dataset_name] = processed_dataset

                # Run each single-dataset task against the processed dataset
                for task_arg_idx, task_arg in enumerate(single_dataset_tasks):
                    log.info(
                        f'Starting task "{task_arg.name}" ({task_arg_idx + 1}/{len(single_dataset_tasks)}) for '
                        f'dataset "{dataset_name}" and model "{model_name}" ({task_arg})'
                    )
                    task_result = run_task(
                        dataset_name, processed_dataset, {model_name: args}, task_arg
                    )
                    task_results.extend(task_result)

    # Run multi-dataset tasks
    embeddings: list[BaseDataset] = list(embeddings_for_multi_dataset_tasks.values())
    for task_arg_idx, task_arg in enumerate(multi_dataset_tasks):
        log.info(
            f'Starting multi-dataset task "{task_arg.name}" ({task_arg_idx + 1}/{len(multi_dataset_tasks)}) '
            f'for datasets "{dataset_names}"'
        )
        model_args_for_run = {
            model_name: permutation[0]
            for model_name, permutation in model_arg_permutations.items()
            if len(permutation) == 1
        }
        task_result = run_multi_dataset_task(
            dataset_names, embeddings, model_args_for_run, task_arg
        )
        task_results.extend(task_result)

    return task_results

def run_inference_or_load_from_cache(
    dataset_name: str,
    *,
    model_name: str,
    model_args: ModelArgsDict,
    cache_options: CacheOptions,
) -> BaseDataset:
    """
    Load the processed dataset from the cache if it exists, else run inference and save to cache.
    """
    processed_dataset = try_processed_datasets_cache(
        dataset_name,
        model_name=model_name,
        model_args=model_args,
        cache_options=cache_options,
    )
    if processed_dataset:
        log.info("Processed dataset is cached: skipping inference")
        return processed_dataset

    dataset = dataset_utils.load_dataset(dataset_name)
    processed_dataset = runner.run_inference(
        model_name,
        dataset,
        gpu=True,
        **model_args,  # type: ignore [arg-type]
    )
    # if we ran inference, put the embeddings produced into the cache (local and possibly remote)
    set_processed_datasets_cache(
        processed_dataset,
        dataset_name,
        model_name=model_name,
        model_args=model_args,
        cache_options=cache_options,
    )
    return processed_dataset

def run_without_inference(
    dataset_names: list[str], task_args: list[TaskArgs]
) -> list[TaskResult]:
    """
    Run a set of tasks directly against raw datasets without first running model inference.
    """
    task_results: list[TaskResult] = []

    single_dataset_task_names = set(task_utils.TASK_NAMES) - set(
        task_utils.MULTI_DATASET_TASK_NAMES
    )
    single_dataset_tasks: list[TaskArgs] = [
        t for t in task_args if t.name in single_dataset_task_names
    ]
    multi_dataset_tasks: list[TaskArgs] = [
        t for t in task_args if t.name in task_utils.MULTI_DATASET_TASK_NAMES
    ]
    embeddings_for_multi_dataset_tasks: dict[str, BaseDataset] = {}

    for dataset_idx, dataset_name in enumerate(dataset_names):
        log.info(
            f'Processing dataset "{dataset_name}" ({dataset_idx + 1}/{len(dataset_names)})'
        )
        dataset = dataset_utils.load_dataset(dataset_name)
        # NOTE: accumulating datasets with attached embeddings in memory
        # can be memory intensive
        if multi_dataset_tasks:
            embeddings_for_multi_dataset_tasks[dataset_name] = dataset

        for task_arg_idx, task_arg in enumerate(single_dataset_tasks):
            log.info(
                f'Starting task "{task_arg.name}" ({task_arg_idx + 1}/{len(single_dataset_tasks)}) '
                f'for dataset "{dataset_name}"'
            )
            task_result = run_task(dataset_name, dataset, {}, task_arg)
            task_results.extend(task_result)

    # Run multi-dataset tasks
    embeddings: list[BaseDataset] = list(embeddings_for_multi_dataset_tasks.values())
    for task_arg_idx, task_arg in enumerate(multi_dataset_tasks):
        log.info(
            f'Starting multi-dataset task "{task_arg.name}" ({task_arg_idx + 1}/{len(multi_dataset_tasks)}) '
            f'for datasets "{dataset_names}"'
        )
        task_result = run_multi_dataset_task(dataset_names, embeddings, {}, task_arg)
        task_results.extend(task_result)

    return task_results

def run_multi_dataset_task(
    dataset_names: list[str],
    embeddings: list[BaseDataset],
    model_args: dict[str, ModelArgsDict],
    task_args: TaskArgs,
) -> list[TaskResult]:
    """
    Run a multi-dataset task and return the results.
    """
    task_results: list[TaskResult] = []
    if task_args.set_baseline:
        raise ValueError("Baseline embedding run not allowed for multi-dataset tasks")

    result: dict[ModelType, list[MetricResult]] = task_args.task.run(embeddings)
    if not isinstance(result, Mapping):
        raise TypeError("Expected a mapping of model types to metric results")

    # sorting the dataset_names for the purposes of using it as a
    # cache key and uniform presentation to the user
    dataset_names.sort()

    for model_type, metrics in result.items():
        task_result = TaskResult(
            task_name=task_args.name,
            task_name_display=task_args.task.display_name,
            model_type=model_type.value,
            dataset_names=dataset_names,
            dataset_names_display=[
                dataset_utils.dataset_to_display_name(ds) for ds in dataset_names
            ],
            model_args=model_args.get(model_type.value) or {},
            metrics=metrics,
        )
        task_results.append(task_result)
        log.info(task_result)
    return task_results

def run_task(
    dataset_name: str,
    dataset: BaseDataset,
    model_args: dict[str, ModelArgsDict],
    task_args: TaskArgs,
) -> list[TaskResult]:
    """
    Run a task and return the results.
    """
    task_results: list[TaskResult] = []
    if task_args.set_baseline:
        dataset.load_data()
        task_args.task.set_baseline(dataset, **task_args.baseline_args)

    result: dict[ModelType, list[MetricResult]] = task_args.task.run(dataset)
    if isinstance(result, list):
        raise TypeError("Expected a single task result, got list")

    for model_type, metrics in result.items():
        if model_type == ModelType.BASELINE:
            model_args_to_store = task_args.baseline_args
        else:
            model_args_to_store = model_args.get(model_type.value) or {}
        task_result = TaskResult(
            task_name=task_args.name,
            task_name_display=task_args.task.display_name,
            model_type=model_type.value,
            dataset_names=[dataset_name],
            dataset_names_display=[
                dataset_utils.dataset_to_display_name(dataset_name)
            ],
            model_args=model_args_to_store,
            metrics=metrics,
        )
        task_results.append(task_result)
        log.info(task_result)
    return task_results

def get_model_arg_permutations(
    model_args: list[ModelArgs],
) -> dict[str, list[ModelArgsDict]]:
    """
    Generate all the "permutations" of model arguments we want to run for each dataset.

    E.g. running 2 variants of scgenept at 2 chunk sizes results in 4 permutations.
    """
    result: dict[str, list[ModelArgsDict]] = defaultdict(list)
    for model_arg in model_args:
        if not model_arg.args:
            result[model_arg.name] = [{}]
            continue
        keys, values = zip(*model_arg.args.items())
        permutations: list[dict[str, str | int]] = [
            {k: v for k, v in zip(keys, permutation)}
            for permutation in itertools.product(*values)
        ]
        result[model_arg.name] = permutations
    return result

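# A minimal illustration of the expansion above (values are hypothetical, not taken
# from the source): a single ModelArgs entry
#   ModelArgs(name="SCGENEPT", args={"model_variant": ["v1", "v2"], "chunk_size": [256, 512]})
# expands to four permutations under the "SCGENEPT" key:
#   {"model_variant": "v1", "chunk_size": 256}, {"model_variant": "v1", "chunk_size": 512},
#   {"model_variant": "v2", "chunk_size": 256}, {"model_variant": "v2", "chunk_size": 512}
# A model with no extra args maps to a single empty permutation, [{}].
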
def write_results(
    task_results: list[TaskResult],
    *,
    cache_options: CacheOptions,
    output_format: str = DEFAULT_OUTPUT_FORMAT,
    output_file: str | None = None,  # Writes to stdout if None
) -> None:
    """
    Format and write results to the given directory or file.
    """
    results_dict = {
        "czbenchmarks_version": cli.get_version(),
        "args": "czbenchmarks " + " ".join(sys.argv[1:]),
        "task_results": [result.model_dump(mode="json") for result in task_results],
    }

    # Get the intended format/extension
    if output_file and output_file.endswith(".json"):
        output_format = "json"
    elif output_file and (
        output_file.endswith(".yaml") or output_file.endswith(".yml")
    ):
        output_format = "yaml"
    elif output_format not in VALID_OUTPUT_FORMATS:
        raise ValueError(f"Invalid output format: {output_format}")

    results_str = ""
    if output_format == "json":
        results_str = json.dumps(results_dict, indent=2)
    else:
        results_str = yaml.dump(results_dict)

    if cache_options.remote_cache_url and cache_options.upload_results:
        remote_url = get_result_url_for_remote(cache_options.remote_cache_url)
        try:
            utils.upload_blob_to_remote(
                results_str.encode("utf-8"), remote_url, overwrite_existing=False
            )
        except exceptions.RemoteStorageError:
            log.exception(f"Failed to upload results to {remote_url!r}")
        else:
            log.info("Uploaded results to %r", remote_url)

    # Generate a unique filename if we were passed a directory
    if output_file and (os.path.isdir(output_file) or output_file.endswith("/")):
        current_time = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(
            output_file, f"czbenchmarks_results_{current_time}.{output_format}"
        )

    if output_file:
        with open(output_file, "w") as f:
            f.write(results_str)
            f.write("\n")
        log.info("Wrote results to %r", output_file)
    else:
        # Write to stdout if not otherwise specified
        sys.stdout.write(results_str)
        sys.stdout.write("\n")

def get_result_url_for_remote(remote_prefix_url: str) -> str:
    nonce = token_hex(4)
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    version = cli.get_version()
    return f"{remote_prefix_url.rstrip('/')}/{version}/results/{timestamp}-{nonce}.json"

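# For illustration (placeholder bucket, version, timestamp, and nonce):
# get_result_url_for_remote("s3://my-bucket/") with version "1.2.3" produces a URL
# of the form
#   s3://my-bucket/1.2.3/results/20250101_120000-1a2b3c4d.json
# where the nonce is 8 hex characters from token_hex(4).
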
def set_processed_datasets_cache(
    dataset: BaseDataset,
    dataset_name: str,
    *,
    model_name: str,
    model_args: ModelArgsDict,
    cache_options: CacheOptions,
) -> None:
    """
    Write a dataset to the cache.

    A "processed" dataset has been run with model inference for the given arguments.
    """
    dataset_filename = get_processed_dataset_cache_filename(
        dataset_name, model_name=model_name, model_args=model_args
    )
    cache_dir = Path(PROCESSED_DATASETS_CACHE_PATH).expanduser().absolute()
    cache_file = cache_dir / dataset_filename
    try:
        # "Unload" the source data so we only cache the results
        dataset.unload_data()
        cache_dir.mkdir(parents=True, exist_ok=True)
        dataset.serialize(str(cache_file))
        succeeded = True
    except Exception as e:
        # Log the exception, but don't raise if we can't write to the cache for some reason
        log.exception(
            f'Failed to serialize processed dataset to cache "{cache_file}": {e}'
        )
        succeeded = False

    if succeeded and cache_options.upload_embeddings:
        # upload the new embeddings, overwriting any that may already exist
        remote_prefix = get_remote_cache_prefix(cache_options)
        try:
            utils.upload_file_to_remote(
                cache_file, remote_prefix, overwrite_existing=True
            )
            log.info(f"Uploaded processed dataset from {cache_file} to {remote_prefix}")
        except exceptions.RemoteStorageError:
            log.exception("Unable to upload processed dataset to remote cache")

    dataset.load_data()

def try_processed_datasets_cache(
    dataset_name: str,
    *,
    model_name: str,
    model_args: ModelArgsDict,
    cache_options: CacheOptions,
) -> BaseDataset | None:
    """
    Deserialize and return a processed dataset from the cache if it exists, else return None.
    """
    dataset_filename = get_processed_dataset_cache_filename(
        dataset_name, model_name=model_name, model_args=model_args
    )
    cache_dir = Path(PROCESSED_DATASETS_CACHE_PATH).expanduser().absolute()
    cache_file = cache_dir / dataset_filename

    if cache_options.download_embeddings:
        # check the remote cache and download the file if a local version doesn't
        # exist, or if the remote version is newer than the local version
        remote_url = f"{get_remote_cache_prefix(cache_options)}{dataset_filename}"
        local_modified: datetime | None = None
        remote_modified: datetime | None = None
        if cache_file.exists():
            local_modified = datetime.fromtimestamp(
                cache_file.stat().st_mtime, tz=timezone.utc
            )
        try:
            remote_modified = utils.get_remote_last_modified(
                remote_url, make_unsigned_request=False
            )
        except exceptions.RemoteStorageError:
            # not a great way to handle this, but maybe the cache bucket is not public
            try:
                log.warning(
                    "Unsigned request to remote storage cache failed. Trying signed request."
                )
                remote_modified = utils.get_remote_last_modified(
                    remote_url, make_unsigned_request=True
                )
            except exceptions.RemoteStorageError:
                pass

        if remote_modified is None:
            log.info("Remote cached embeddings don't exist. Skipping download.")
        elif local_modified is not None and (remote_modified <= local_modified):
            log.info(
                f"Remote cached embeddings modified at {remote_modified}. "
                f"Local cache files modified more recently at {local_modified}. "
                "Skipping download."
            )
        else:
            try:
                utils.download_file_from_remote(remote_url, cache_dir)
                log.info(
                    f"Downloaded cached embeddings from {remote_url} to {cache_dir}"
                )
            except exceptions.RemoteStorageError:
                # not a great way to handle this, but maybe the cache bucket is not public
                try:
                    log.warning(
                        "Unsigned request to remote storage cache failed. Trying signed request."
                    )
                    utils.download_file_from_remote(
                        remote_url, cache_dir, make_unsigned_request=False
                    )
                    log.info(
                        f"Downloaded cached embeddings from {remote_url} to {cache_dir}"
                    )
                except exceptions.RemoteStorageError:
                    log.warning(
                        f"Unable to retrieve embeddings from remote cache at {remote_url!r}"
                    )

    if cache_file.exists():
        # Load the original dataset
        dataset = dataset_utils.load_dataset(dataset_name)
        dataset.load_data()
        # Attach the cached results to the dataset
        processed_dataset = BaseDataset.deserialize(str(cache_file))
        dataset._outputs = processed_dataset._outputs
        return dataset
    return None

def get_remote_cache_prefix(cache_options: CacheOptions) -> str:
    """Get the prefix ending in '/' that the remote processed datasets go under."""
    return f"{cache_options.remote_cache_url.rstrip('/')}/{cli.get_version()}/processed-datasets/"

def get_processed_dataset_cache_filename(
    dataset_name: str, *, model_name: str, model_args: ModelArgsDict
) -> str:
    """
    Generate a unique filename for the given dataset and model arguments.
    """
    if model_args:
        model_args_str = f"{model_name}_" + "_".join(
            f"{k}-{v}" for k, v in sorted(model_args.items())
        )
    else:
        model_args_str = model_name
    filename = f"{dataset_name}_{model_args_str}.dill"
    return filename

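# For illustration (placeholder names): dataset "my_dataset" with model "SCVI" and
# model_args {"model_variant": "homo_sapiens"} yields
#   "my_dataset_SCVI_model_variant-homo_sapiens.dill"
# while an empty model_args dict yields "my_dataset_SCVI.dill".
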
def get_processed_dataset_cache_path(
    dataset_name: str, *, model_name: str, model_args: ModelArgsDict
) -> Path:
    """
    Return a unique file path in the cache directory for the given dataset and model arguments.
    """
    cache_dir = Path(PROCESSED_DATASETS_CACHE_PATH).expanduser().absolute()
    filename = get_processed_dataset_cache_filename(
        dataset_name, model_name=model_name, model_args=model_args
    )
    return cache_dir / filename

def parse_model_args(model_name: str, args: argparse.Namespace) -> ModelArgs:
    """
    Populate a ModelArgs instance from the given argparse namespace.
    """
    prefix = f"{model_name.lower()}_"
    model_args: dict[str, Any] = {}
    for k, v in vars(args).items():
        if v is not None and k.startswith(prefix):
            model_args[k.removeprefix(prefix)] = v
    return ModelArgs(name=model_name.upper(), args=model_args)

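# For illustration (placeholder values): with model_name "scvi" and an argparse
# namespace containing scvi_model_variant=["homo_sapiens"], parse_model_args returns
#   ModelArgs(name="SCVI", args={"model_variant": ["homo_sapiens"]})
# Attributes that are None, or that lack the "scvi_" prefix, are ignored.
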
def parse_task_args(
    task_name: str, TaskCls: type[TaskType], args: argparse.Namespace
) -> TaskArgs:
    """
    Populate a TaskArgs instance from the given argparse namespace.
    """
    prefix = f"{task_name.lower()}_task_"
    task_args: dict[str, Any] = {}
    baseline_args: dict[str, Any] = {}
    for k, v in vars(args).items():
        if v is not None and k.startswith(prefix):
            trimmed_k = k.removeprefix(prefix)
            if trimmed_k.startswith("baseline_"):
                baseline_args[trimmed_k.removeprefix("baseline_")] = v
            else:
                task_args[trimmed_k] = v
    set_baseline = task_args.pop("set_baseline", False)
    return TaskArgs(
        name=task_name,
        task=TaskCls(**task_args),
        set_baseline=set_baseline,
        baseline_args=baseline_args,
    )

def parse_batch_json(batch_json_list: list[str]) -> list[dict[str, Any]]:
    """
    Parse the `--batch-json` argument.

    Returns a list of dicts where each entry is a batch of CLI arguments.
    """
    batches: list[dict[str, Any]] = []
    if not batch_json_list:
        return [{}]

    for batch_json in batch_json_list:
        if not batch_json.strip():
            batches.append({})
            continue

        # Load JSON from disk if we were given a valid file path
        if os.path.isfile(batch_json):
            try:
                with open(batch_json, "r") as f:
                    batches.append(json.load(f))
            except Exception as e:
                raise ValueError(
                    f"Failed to load batch JSON from file {batch_json}: {e}"
                ) from e
            continue

        # Otherwise treat the input as JSON
        try:
            result = json.loads(batch_json)
            if isinstance(result, list):
                batches.extend(result)
            elif isinstance(result, dict):
                batches.append(result)
            else:
                raise ValueError(
                    "Invalid batch JSON: input must be a dictionary of CLI arguments"
                )
            continue
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid batch JSON {batch_json}: {e}") from e

    return batches

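# Illustrative inputs accepted by parse_batch_json (all values are placeholders):
#   ['{"datasets": ["ds_a"], "output_file": "a.json"}']    -> one batch
#   ['[{"datasets": ["ds_a"]}, {"datasets": ["ds_b"]}]']   -> two batches
#   ['batches.json']   (a path to a JSON file; its contents are loaded from disk)
#   ['']               -> a single empty batch, i.e. the unmodified CLI arguments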