Source code for czbenchmarks.file_utils

import logging
import os
from datetime import datetime, timedelta
from pathlib import Path

import boto3
import botocore
from botocore.config import Config

from czbenchmarks.constants import DATASETS_CACHE_PATH
from czbenchmarks.exceptions import RemoteStorageError

log = logging.getLogger(__name__)

# Cache defaults, overridable via environment variables
DEFAULT_CACHE_DIR = os.getenv("DATASETS_CACHE_PATH", DATASETS_CACHE_PATH)
DEFAULT_CACHE_EXPIRATION_DAYS = int(os.getenv("CZBENCHMARKS_CACHE_EXPIRATION_DAYS", 30))
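
# Both defaults can be overridden via environment variables before this module
# is imported; an illustrative shell snippet (the path is hypothetical):
#
#     export DATASETS_CACHE_PATH=/data/czb-cache
#     export CZBENCHMARKS_CACHE_EXPIRATION_DAYS=7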


class CacheManager:
    """Centralized cache management for remote files."""

    def __init__(
        self,
        cache_dir: str | Path = DEFAULT_CACHE_DIR,
        expiration_days: int = DEFAULT_CACHE_EXPIRATION_DAYS,
    ):
        self.cache_dir = Path(cache_dir).expanduser()
        self.expiration_days = expiration_days
        self.ensure_directory_exists(self.cache_dir)

    def ensure_directory_exists(self, directory: Path) -> None:
        """Ensure the given directory exists."""
        directory.mkdir(parents=True, exist_ok=True)

    def get_cache_path(self, remote_url: str) -> Path:
        """Generate a local cache path for a remote file."""
        filename = Path(remote_url).name
        return self.cache_dir / filename

    def is_expired(self, file_path: Path) -> bool:
        """Check if a cached file is expired."""
        if not file_path.exists():
            return True
        last_modified = datetime.fromtimestamp(file_path.stat().st_mtime)
        return datetime.now() - last_modified > timedelta(days=self.expiration_days)

    def clean_expired_cache(self) -> None:
        """Clean up expired cache files."""
        for file in self.cache_dir.iterdir():
            if self.is_expired(file):
                log.info(f"Removing expired cache file: {file}")
                file.unlink()
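
# Example (illustrative sketch, not part of the module): using CacheManager
# directly; the directory and URL below are hypothetical.
#
#     manager = CacheManager(cache_dir="~/.cache/czbenchmarks", expiration_days=7)
#     local = manager.get_cache_path("s3://example-bucket/datasets/pbmc.h5ad")
#     # local == Path("~/.cache/czbenchmarks").expanduser() / "pbmc.h5ad"
#     manager.clean_expired_cache()  # unlinks files older than expiration_days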

# Default cache manager instance
_default_cache_manager = CacheManager()


def _get_s3_client(make_unsigned_request: bool = True) -> boto3.client:
    """Get an S3 client with optional unsigned requests."""
    if make_unsigned_request:
        return boto3.client("s3", config=Config(signature_version=botocore.UNSIGNED))
    else:
        return boto3.client("s3")
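
# Note: with Config(signature_version=botocore.UNSIGNED), boto3 sends anonymous
# (unsigned) requests, which need no AWS credentials but succeed only against
# publicly readable buckets. Illustrative sketch (bucket and key hypothetical):
#
#     s3 = _get_s3_client(make_unsigned_request=True)
#     s3.download_file("example-public-bucket", "data/file.bin", "/tmp/file.bin")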

def download_file_from_remote(
    remote_url: str,
    cache_dir: str | Path | None = None,
    make_unsigned_request: bool = True,
) -> str:
    """
    Download a remote file to a local cache directory.

    Args:
        remote_url (str): Remote URL of the file (e.g., S3 path).
        cache_dir (str | Path, optional): Local directory to save the file.
            Defaults to the global cache manager's directory.
        make_unsigned_request (bool, optional): Whether to use unsigned
            requests for S3 (default: True).

    Returns:
        str: Local path to the downloaded file.

    Raises:
        ValueError: If the remote URL is invalid.
        RemoteStorageError: If the file download fails due to S3 errors.

    Notes:
        - If the file already exists in the cache and is not expired, it will
          not be downloaded again.
        - Unsigned requests are tried first, followed by signed requests if
          the former fails.
    """
    cache_manager = (
        _default_cache_manager if cache_dir is None else CacheManager(cache_dir)
    )
    try:
        bucket, remote_key = remote_url.removeprefix("s3://").split("/", 1)
    except ValueError:
        raise ValueError(f"Invalid remote URL: {remote_url}")

    local_file = cache_manager.get_cache_path(remote_url)
    if local_file.exists() and not cache_manager.is_expired(local_file):
        log.info(f"File already exists in cache: {local_file}")
        return str(local_file)

    s3 = _get_s3_client(make_unsigned_request)
    try:
        s3.download_file(bucket, remote_key, str(local_file))
    except botocore.exceptions.ClientError:
        if not make_unsigned_request:
            raise
        # An unsigned request against a private bucket fails with a client
        # error; retry once with the caller's AWS credentials.
        log.warning("Unsigned request failed. Trying signed request.")
        s3 = _get_s3_client(make_unsigned_request=False)
        s3.download_file(bucket, remote_key, str(local_file))
    except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
        raise RemoteStorageError(
            f"Failed to download {remote_url} to {local_file}"
        ) from e

    log.info(f"Downloaded file to cache: {local_file}")
    return str(local_file)
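
# Example (illustrative sketch): a typical call against a hypothetical bucket.
# On a cache hit (file present and not expired) the function returns the local
# path without contacting S3.
#
#     local_path = download_file_from_remote(
#         "s3://example-bucket/datasets/pbmc.h5ad",
#         cache_dir="/tmp/czb-cache",  # optional; defaults to the shared cache
#     )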