Source code for czbenchmarks.cli.types

import argparse
from functools import cached_property
import operator
from typing import Any, Generic, TypeVar

from pydantic import BaseModel, computed_field

from czbenchmarks.datasets import utils as dataset_utils
from czbenchmarks.metrics.types import AggregatedMetricResult, MetricResult
from czbenchmarks.models.types import ModelType
from czbenchmarks.models import utils as model_utils
from czbenchmarks.tasks.base import BaseTask


# Type variable for task classes; bound to BaseTask, so only BaseTask and its
# subclasses are accepted wherever TaskType is used as a generic parameter.
TaskType = TypeVar("TaskType", bound=BaseTask)
# Arguments passed to model inference: argument name -> str or int value.
ModelArgsDict = dict[str, str | int]
# Runtime metrics like elapsed time or CPU count: metric name -> str, int or
# float value. Not implemented yet.
RuntimeMetricsDict = dict[
    str, str | int | float
]


class ModelArgs(BaseModel):
    """A model name together with the argument values supplied for it."""

    # Upper-case model name, e.g. SCVI.
    name: str

    # Args forwarded to the model container; every argument name maps to a
    # list of str or int values.
    args: dict[str, list[str | int]]
class TaskArgs(BaseModel, Generic[TaskType]):
    """A task instance bundled with its name and baseline settings.

    Generic over ``TaskType`` so the ``task`` field keeps the concrete task
    class it was created with.
    """

    # Required to support the TaskType-typed ``task`` field.
    model_config = {"arbitrary_types_allowed": True}

    # Lower-case task name, e.g. embedding.
    name: str

    task: TaskType

    set_baseline: bool

    baseline_args: dict[str, Any]
class DatasetDetail(BaseModel):
    """A dataset's name and organism, plus display strings derived from the name."""

    name: str
    organism: str

    @cached_property
    def _display_info(self) -> tuple[str, str]:
        """Return the display strings for ``self.name``, computed once per instance.

        Delegates to ``dataset_utils.dataset_to_display``. Element 0 backs
        ``name_display`` and element 1 backs ``subset_display``.
        """
        return dataset_utils.dataset_to_display(self.name)

    @computed_field
    @property
    def name_display(self) -> str:
        """First element of the cached display info."""
        display = self._display_info
        return display[0]

    @computed_field
    @property
    def subset_display(self) -> str:
        """Second element of the cached display info."""
        display = self._display_info
        return display[1]
class ModelDetail(BaseModel):
    """A model's type and inference arguments, plus derived display strings."""

    type: ModelType
    args: ModelArgsDict

    @cached_property
    def _display_info(self) -> tuple[str, str]:
        """Return the display strings for this model, computed once per instance.

        Delegates to ``model_utils.model_to_display`` with the model type and
        its args. Element 0 backs ``name_display`` and element 1 backs
        ``variant_display``.
        """
        return model_utils.model_to_display(self.type, self.args)

    @computed_field
    @property
    def name_display(self) -> str:
        """First element of the cached display info."""
        display = self._display_info
        return display[0]

    @computed_field
    @property
    def variant_display(self) -> str:
        """Second element of the cached display info."""
        display = self._display_info
        return display[1]
class TaskResult(BaseModel):
    """Metrics produced by one task for one model on a set of datasets."""

    task_name: str
    task_name_display: str
    model: ModelDetail
    datasets: list[DatasetDetail]
    metrics: list[MetricResult | AggregatedMetricResult]
    # Not implementing any of these for now.
    runtime_metrics: RuntimeMetricsDict = {}

    @property
    def aggregation_key(self) -> str:
        """Return a key identifying the task, model and datasets of this result.

        Results that share a key can be aggregated together. The key has the
        form ``<task_name>|<model type>(<args>)|<datasets>`` where ``<args>``
        is the model args sorted by name and rendered as ``name-value`` pairs
        joined with ``_``, and ``<datasets>`` is the dataset names sorted and
        joined with ``,``. Sorting makes the key independent of the order in
        which args and datasets were supplied.
        """
        dataset_names = sorted(ds.name for ds in self.datasets)
        datasets = ",".join(dataset_names)

        arg_parts = [
            f"{key}-{value!s}" for key, value in sorted(self.model.args.items())
        ]
        model_args = "_".join(arg_parts)

        return f"{self.task_name}|{self.model.type}({model_args})|{datasets}"
class CacheOptions(BaseModel):
    """Remote-cache settings: the cache URL and which transfers are enabled."""

    download_embeddings: bool
    upload_embeddings: bool
    upload_results: bool
    remote_cache_url: str

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "CacheOptions":
        """Build the options from parsed command-line arguments.

        Reads ``remote_cache_url``, ``remote_cache_download_embeddings``,
        ``remote_cache_upload_embeddings`` and ``remote_cache_upload_results``
        from ``args``. A falsy URL is stored as ``""`` and forces all three
        flags to ``False``; in that case the flag attributes are not read.
        """
        url = args.remote_cache_url or ""
        # Without a remote cache URL, no download or upload can be enabled.
        has_remote = bool(url)
        return cls(
            remote_cache_url=url,
            download_embeddings=has_remote and args.remote_cache_download_embeddings,
            upload_embeddings=has_remote and args.remote_cache_upload_embeddings,
            upload_results=has_remote and args.remote_cache_upload_results,
        )