# Source code for czbenchmarks.utils

import importlib
import logging

import hydra
from omegaconf import OmegaConf

# Raise the threshold of the third-party botocore loggers (the package logger
# and its httpchecksum submodule) to WARNING, so their DEBUG/INFO records are
# dropped at the logger.
for _noisy_logger_name in ("botocore", "botocore.httpchecksum"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
# Keep the module namespace free of the loop variable.
del _noisy_logger_name

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)


def initialize_hydra(config_path="./conf"):
    """
    Set up Hydra's global configuration system from a config directory.

    Any Hydra instance that is already initialized is cleared first, so the
    configuration directory given here is the one that takes effect, even
    when this function is called more than once.

    Args:
        config_path (str): Configuration directory passed to
            ``hydra.initialize``. Defaults to ``"./conf"``.

    Returns:
        None
    """
    global_hydra = hydra.core.global_hydra.GlobalHydra.instance()

    # Reset any existing global Hydra state before initializing again.
    if global_hydra.is_initialized():
        global_hydra.clear()

    # NOTE(review): per the Hydra docs, a relative config_path given to
    # hydra.initialize is resolved relative to the calling module (this file),
    # not the process working directory -- confirm this is the intent.
    hydra.initialize(config_path=config_path, version_base=None)
def import_class_from_config(config_path: str):
    """
    Import the class named by the ``_target_`` field of a configuration file.

    The file is loaded with ``OmegaConf.load``. Its ``_target_`` value must be
    a fully qualified dotted path such as ``"package.module.ClassName"``. The
    part before the last dot is imported as a module, and the final component
    is looked up on that module.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        class_obj: The imported class object.

    Raises:
        ValueError: If ``_target_`` is not of the form ``"module.attribute"``
            (no dot, or an empty module or attribute part).
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the given name.
            A config without a ``_target_`` key is presumably also reported as
            an ``AttributeError`` subclass by OmegaConf -- confirm for the
            installed OmegaConf version.
    """
    logger.info("Loading model configuration from %s", config_path)
    cfg = OmegaConf.load(config_path)

    target_path = cfg._target_

    # Split on the last dot: everything before it is the module, the rest is
    # the attribute. A bare name (or a leading/trailing dot) leaves one side
    # empty; reject it with a clear message instead of an opaque
    # tuple-unpacking or "Empty module name" error.
    module_path, _, class_name = target_path.rpartition(".")
    if not module_path or not class_name:
        raise ValueError(
            f"`_target_` must be a fully qualified 'module.ClassName' path, "
            f"got {target_path!r} in {config_path}"
        )

    module = importlib.import_module(module_path)
    try:
        class_obj = getattr(module, class_name)
    except AttributeError as err:
        raise AttributeError(
            f"Module {module_path!r} has no attribute {class_name!r} "
            f"(from `_target_` {target_path!r} in {config_path})"
        ) from err

    logger.info("Imported class: %s", class_obj.__name__)
    return class_obj