import logging
from typing import List
import hydra
from omegaconf import OmegaConf
from .types import ModelType
from ..utils import initialize_hydra
logger = logging.getLogger(__name__)
def list_available_models() -> List[str]:
    """
    Lists all available models defined in the models.yaml configuration file.

    Returns:
        List[str]: A sorted list of model names available in the configuration.
    """
    initialize_hydra()
    # Load the models configuration
    cfg = OmegaConf.to_container(hydra.compose(config_name="models"), resolve=True)
    # Extract model names
    model_names = list(cfg.get("models", {}).keys())
    # Sort alphabetically for easier reading
    model_names.sort()
    return model_names
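
# Illustrative usage of list_available_models(). The output shown below is
# hypothetical; the actual names depend on the entries in the local models.yaml:
#
#     >>> list_available_models()
#     ['AIDO', 'GENEFORMER', 'SCGENEPT', 'SCGPT', 'SCVI', 'TRANSCRIPTFORMER', 'UCE']
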
_MODEL_VARIANT_FINETUNE_TO_DISPLAY = {  # maps a tuple of (name, variant, fine_tune_dataset) to display (model name, details)
    # AIDO
    ("AIDO", "aido_cell_3m", None): ("AIDO", "Cell-3M"),
    # Geneformer
    ("GENEFORMER", "Geneformer-V2-316M", None): (
        "GENEFORMER",
        "Geneformer-V2-316M (May 2025)",
    ),
    # scGenePT
    (
        "SCGENEPT",
        "scgenept_ncbi_gpt",
        "adamson",
    ): ("scGenePT_{NCBI}", "fine-tuned, Adamson"),
    ("SCGENEPT", "scgenept_ncbi_gpt", "norman"): (
        "scGenePT_{NCBI}",
        "fine-tuned, Norman",
    ),
    (
        "SCGENEPT",
        "scgenept_ncbi+uniprot_gpt",
        "adamson",
    ): ("scGenePT_{NCBI+UniProt}", "fine-tuned, Adamson"),
    (
        "SCGENEPT",
        "scgenept_ncbi+uniprot_gpt",
        "norman",
    ): ("scGenePT_{NCBI+UniProt}", "fine-tuned, Norman"),
    (
        "SCGENEPT",
        "scgenept_go_f_gpt_concat",
        "adamson",
    ): ("scGenePT_{GO-F}", "fine-tuned, Adamson"),
    (
        "SCGENEPT",
        "scgenept_go_f_gpt_concat",
        "norman",
    ): ("scGenePT_{GO-F}", "fine-tuned, Norman"),
    (
        "SCGENEPT",
        "scgenept_go_p_gpt_concat",
        "adamson",
    ): ("scGenePT_{GO-P}", "fine-tuned, Adamson"),
    (
        "SCGENEPT",
        "scgenept_go_p_gpt_concat",
        "norman",
    ): ("scGenePT_{GO-P}", "fine-tuned, Norman"),
    (
        "SCGENEPT",
        "scgenept_go_c_gpt_concat",
        "adamson",
    ): ("scGenePT_{GO-C}", "fine-tuned, Adamson"),
    (
        "SCGENEPT",
        "scgenept_go_c_gpt_concat",
        "norman",
    ): ("scGenePT_{GO-C}", "fine-tuned, Norman"),
    (
        "SCGENEPT",
        "scgenept_go_all_gpt_concat",
        "adamson",
    ): ("scGenePT_{GO-all}", "fine-tuned, Adamson"),
    (
        "SCGENEPT",
        "scgenept_go_all_gpt_concat",
        "norman",
    ): ("scGenePT_{GO-all}", "fine-tuned, Norman"),
    # scGPT
    ("SCGPT", "human", None): ("scGPT", "whole-human"),
    # scVI
    ("SCVI", "homo_sapiens", None): ("scVI", "Census 2023-12-15, homo sapiens"),
    # Transcriptformer
    ("TRANSCRIPTFORMER", "tf-sapiens", None): ("TRANSCRIPTFORMER", "TF-Sapiens"),
    ("TRANSCRIPTFORMER", "tf-exemplar", None): ("TRANSCRIPTFORMER", "TF-Exemplar"),
    ("TRANSCRIPTFORMER", "tf-metazoa", None): ("TRANSCRIPTFORMER", "TF-Metazoa"),
    # UCE
    ("UCE", "4l", None): ("UCE", "4L"),
    ("UCE", "33l", None): ("UCE", "33L"),
}
def model_to_display(
    model_type: ModelType, model_args: dict[str, str | int]
) -> tuple[str, str]:
    """Map a model's type and variant to uniform, human-readable display strings.

    Falls back to a title-cased model name and the raw model arguments when no
    explicit display mapping is defined.
    """
    key = (model_type.name, model_args.get("model_variant"), model_args.get("dataset"))
    try:
        return _MODEL_VARIANT_FINETUNE_TO_DISPLAY[key]
    except KeyError:
        if model_args:
            logger.warning(
                "No display name provided for %r. Using title-case.", model_type
            )
        return (
            model_type.name.title(),
            ", ".join(f"{k}={v!r}" for k, v in model_args.items()),
        )
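
# Illustrative usage of model_to_display(). This assumes ModelType exposes an SCGPT
# member, as implied by the mapping table above:
#
#     >>> model_to_display(ModelType.SCGPT, {"model_variant": "human"})
#     ('scGPT', 'whole-human')
#
# An unmapped combination falls back to the title-cased model name plus the raw
# arguments (and logs a warning):
#
#     >>> model_to_display(ModelType.SCGPT, {"model_variant": "some_other_variant"})
#     ('Scgpt', "model_variant='some_other_variant'")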