{ "cells": [ { "cell_type": "markdown", "id": "a6b9d9f4", "metadata": {}, "source": [ "## Model Developer Workflow Example\n", "\n", "### Using czbenchmarks to evaluate an in-development model on a single task\n", "\n", "This notebook demonstrates how to leverage the czbenchmarks library to generate benchmark metrics for a previously published model (using existing model weights) and integrate czbenchmarks into the developer workflow for iterative evaluation of model performance during parameter tuning.\n", "\n", "We focus on the `Cell Clustering` task as an example, but the same approach can be applied to other tasks supported by `czbenchmarks`. For a comprehensive overview of all tasks, refer to the `scvi_all_task_sbenchmark.ipynb` notebook.\n", "\n", "In this example, we use scVI, a popular tool for single-cell analysis, to compare the performance of published model weights against variations obtained through iterative re-training. \n", "\n", "> **NOTE**: This workflow can also be adapted for the development and evaluation of entirely new models.\n", "\n", "### Key Highlights:\n", "- **Benchmark Metrics**: Evaluate clustering performance using Adjusted Rand Index (ARI) and Normalized Mutual Information (NMI).\n", "- **Iterative Development**: Demonstrate how to fine-tune model parameters and assess performance improvements.\n", "- **Generalization**: Showcase how the workflow can be extended to other tasks and models.\n", "\n", "### Step 1: Setup and Imports\n", "\n", "To begin, ensure your environment is properly configured. This includes setting up a virtual environment, installing required dependencies, and registering the environment as a Jupyter kernel. 
Below is the setup process:\n", "\n", "#### Virtual Environment Setup (Optional)\n", "If you need to create a new virtual environment, uncomment and run the commands in the cell below.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "0b596f76", "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "# # Create an isolated virtual environment for scVI and czbenchmarks (run once)\n", "\n", "# !python3 -m venv .venv_scvi\n", "\n", "# # Install the packages required by the model\n", "# !.venv_scvi/bin/python -m pip install --upgrade pip\n", "# !.venv_scvi/bin/python -m pip install ipykernel numpy pandas scvi-tools matplotlib seaborn\n", "\n", "# # Register the new environment as a Jupyter kernel (if not already registered)\n", "# !.venv_scvi/bin/python -m ipykernel install --user --name venv_scvi --display-name \"Python (.venv_scvi)\"\n", "\n", "# print(\"Virtual environment '.venv_scvi' created, dependencies installed, and kernel registered.\")\n" ] }, { "cell_type": "markdown", "id": "5d093bce", "metadata": {}, "source": [ "#### Import Libraries\n", "\n", "This notebook requires the following libraries:\n", "\n", "- **czbenchmarks**: For dataset loading and task evaluation.\n", "- **scVI**: For model inference and fine-tuning.\n", "- **Visualization tools**: For plotting benchmark results (matplotlib, seaborn, pandas)."
] }, { "cell_type": "code", "execution_count": 2, "id": "1f6678a2", "metadata": {}, "outputs": [], "source": [ "from czbenchmarks.datasets import load_dataset\n", "from czbenchmarks.datasets.single_cell_labeled import SingleCellLabeledDataset\n", "from czbenchmarks.tasks import ClusteringTask\n", "from czbenchmarks.tasks.clustering import ClusteringTaskInput\n", "\n", "import scvi\n", "import functools\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import seaborn as sns\n", "import warnings\n", "\n", "warnings.simplefilter(\"ignore\")\n", "sns.set_theme(style=\"whitegrid\")" ] }, { "cell_type": "markdown", "id": "5f9bcfda", "metadata": {}, "source": [ "### Step 2: Load and Prepare the Dataset\n", "\n", "\n", "In this step, we load the pre-configured `tsv2_prostate` dataset, which is specifically designed for single-cell analysis tasks. The `czbenchmarks` library simplifies this process by automatically handling dataset download, caching, and loading as a `SingleCellLabeledDataset`.\n", "\n", "### Key Features:\n", "- **Gene Expression Data (`dataset.adata`)**: An AnnData object with gene expression matrices and metadata.\n", "- **Cell Type Labels (`dataset.labels`)**: A pandas Series containing cell type annotations, which serve as ground truth for benchmarking tasks like clustering and label prediction.\n", "\n", "---\n", "\n", "> **NOTE**: Always verify that the dataset is compatible with the model's input requirements before running inference."
] }, { "cell_type": "code", "execution_count": 3, "id": "20bb6d6e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:czbenchmarks.file_utils:File already exists in cache: /Users/sgupta/.cz-benchmarks/datasets/homo_sapiens_10df7690-6d10-4029-a47e-0f071bb2df83_Prostate_v2_curated.h5ad\n", "INFO:czbenchmarks.datasets.single_cell:Loading dataset from /Users/sgupta/.cz-benchmarks/datasets/homo_sapiens_10df7690-6d10-4029-a47e-0f071bb2df83_Prostate_v2_curated.h5ad in memory mode.\n" ] }, { "data": { "text/plain": [ "TSP25_Prostate_NA_10X_1_1_AAACCCAAGTGGTTAA endothelial cell\n", "TSP25_Prostate_NA_10X_1_1_AAACCCACATGCACTA luminal cell of prostate epithelium\n", "TSP25_Prostate_NA_10X_1_1_AAACGAAGTTCTGACA endothelial cell\n", "TSP25_Prostate_NA_10X_1_1_AAACGCTTCTACCCAC erythrocyte\n", "TSP25_Prostate_NA_10X_1_1_AAAGAACCAGTTGTCA smooth muscle cell\n", " ... \n", "TSP25_Prostate_NA_10X_1_2_TTTATGCTCTTGGTCC fibroblast\n", "TSP25_Prostate_NA_10X_1_2_TTTCACAAGATCGGTG basal cell of prostate epithelium\n", "TSP25_Prostate_NA_10X_1_2_TTTCACAGTGCCTTCT fibroblast\n", "TSP25_Prostate_NA_10X_1_2_TTTCATGCAATAGTAG CD8-positive, alpha-beta T cell\n", "TSP25_Prostate_NA_10X_1_2_TTTCCTCAGGTGATCG fibroblast\n", "Name: cell_type, Length: 2044, dtype: category\n", "Categories (14, object): ['fibroblast', 'T cell', 'mast cell', 'endothelial cell', ..., 'neutrophil', 'mature NK T cell', 'luminal cell of prostate epithelium', 'basal cell of prostate epithelium']" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset: SingleCellLabeledDataset = load_dataset(\"tsv2_prostate\")\n", "dataset.adata\n", "dataset.labels" ] }, { "cell_type": "markdown", "id": "b4df4868", "metadata": {}, "source": [ "#### Optionally Transform Data\n", "\n", "After loading the dataset, you may need to transform the data to meet the requirements of your model." 
] }, { "cell_type": "code", "execution_count": 4, "id": "6ab4ce47", "metadata": {}, "outputs": [], "source": [ "# Prepare the dataset for the scVI model: build one batch key by concatenating the required obs columns\n", "adata = dataset.adata.copy()\n", "required_obs_keys = [\"dataset_id\", \"assay\", \"suspension_type\", \"donor_id\"]\n", "adata.obs[\"batch\"] = functools.reduce(\n", "    lambda a, b: a + b, [adata.obs[c].astype(str) for c in required_obs_keys]\n", ")" ] }, { "cell_type": "markdown", "id": "78a14fb9", "metadata": {}, "source": [ "### Step 3: Obtain Pre-trained Model, Run Model Inference, and Generate Output\n", "\n", "In this step, we leverage the pre-trained scVI model to generate cell embeddings for evaluation within the benchmarking framework. The pre-trained model weights serve as a reference point for comparison against fine-tuned or newly developed model variants.\n", "\n", "---\n", "\n", "> **NOTE**: For your own model, adapt the loading and inference steps to match your model's requirements. " ] }, { "cell_type": "markdown", "id": "7a5a184b", "metadata": {}, "source": [ "#### Load Pre-trained Model Weights" ] }, { "cell_type": "code", "execution_count": 5, "id": "28b579ad", "metadata": {}, "outputs": [], "source": [ "import os\n", "import boto3\n", "\n", "\n", "def download_scvi_weights(local_model_dir=\"czbenchmarks_scvi_model\"):\n", "    if not os.path.exists(local_model_dir):\n", "        os.makedirs(local_model_dir, exist_ok=True)\n", "    # Download only when the local directory is empty\n", "    if not os.listdir(local_model_dir):\n", "        s3 = boto3.client(\"s3\")\n", "        bucket_name = \"cz-benchmarks-data\"\n", "        prefix = \"models/v1/scvi_2023_12_15/homo_sapiens/\"\n", "        print(\"Downloading model weights from S3...\")\n", "        paginator = s3.get_paginator(\"list_objects_v2\")\n", "        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):\n", "            for obj in page.get(\"Contents\", []):\n", "                key = obj[\"Key\"]\n", "                if key.endswith(\"/\"):\n", "                    continue\n", "                local_path = os.path.join(local_model_dir, os.path.relpath(key, prefix))\n", "                os.makedirs(os.path.dirname(local_path), exist_ok=True)\n", "                s3.download_file(bucket_name, key, local_path)\n", "        print(f\"Downloaded model weights to {local_model_dir}\\n\")\n", "    return local_model_dir\n", "\n", "\n", "model_weights_dir = download_scvi_weights()" ] }, { "cell_type": "markdown", "id": "a12a5b8d", "metadata": {}, "source": [ "#### Generate Embeddings with Pre-trained Model\n", "\n", "Extract cell embeddings (latent representations) from the model, which will be used for downstream benchmarking tasks." ] }, { "cell_type": "code", "execution_count": 6, "id": "1aed62da", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[34mINFO \u001b[0m File czbenchmarks_scvi_model/model.pt already downloaded \n", "\u001b[34mINFO \u001b[0m Found \u001b[1;36m44.05\u001b[0m% reference vars in query data. \n", "\u001b[34mINFO \u001b[0m File czbenchmarks_scvi_model/model.pt already downloaded \n", "Generated scVI embedding with shape: (2044, 50)\n" ] } ], "source": [ "scvi.model.SCVI.prepare_query_anndata(adata, model_weights_dir)\n", "scvi_model = scvi.model.SCVI.load_query_data(adata, model_weights_dir)\n", "scvi_model.is_trained = True\n", "scvi_output_embedding = scvi_model.get_latent_representation()\n", "print(f\"Generated scVI embedding with shape: {scvi_output_embedding.shape}\")" ] }, { "cell_type": "markdown", "id": "33a034dc-a4f1-417b-bb5b-57cae6072a71", "metadata": {}, "source": [ "### Step 4: Clustering Task Evaluation\n", "\n", "Evaluate the quality of embeddings generated by pre-trained and fine-tuned scVI models using the `ClusteringTask` from `czbenchmarks`. Metrics include:\n", "\n", "- **Adjusted Rand Index (ARI)**: Measures similarity between predicted clusters and ground truth labels.\n", "- **Normalized Mutual Information (NMI)**: Quantifies mutual dependence between predicted clusters and ground truth labels.\n", "\n", "Higher ARI and NMI scores indicate better clustering performance."
] }, { "cell_type": "code", "execution_count": 7, "id": "5c838888", "metadata": {}, "outputs": [], "source": [ "# Run the clustering task on an embedding and return a {metric_type: value} dict\n", "def run_clustering_benchmark(cell_representation, obs, labels):\n", "    task = ClusteringTask()\n", "    task_input = ClusteringTaskInput(obs=obs, input_labels=labels)\n", "    results = task.run(cell_representation=cell_representation, task_input=task_input)\n", "    return {r.metric_type: r.value for r in results}" ] }, { "cell_type": "markdown", "id": "fbad691c", "metadata": {}, "source": [ "#### Pre-trained Model Evaluation\n", "\n", "Benchmark the clustering performance of the embeddings generated by the pre-trained scVI model." ] }, { "cell_type": "code", "execution_count": 8, "id": "17a912db-2c31-480e-83c7-eebb2f2dd77e", "metadata": {}, "outputs": [], "source": [ "benchmark_results = {}\n", "\n", "# Pre-trained scVI\n", "benchmark_results[\"scvi_pretrained\"] = run_clustering_benchmark(\n", "    scvi_output_embedding, dataset.adata.obs, dataset.labels\n", ")" ] }, { "cell_type": "markdown", "id": "be13cbf0", "metadata": {}, "source": [ "#### Baseline Comparison\n", "\n", "Compute clustering metrics for a baseline method (e.g., PCA) to establish a reference point for comparison." ] }, { "cell_type": "code", "execution_count": 9, "id": "6c39fb3b", "metadata": {}, "outputs": [], "source": [ "baseline_embedding = ClusteringTask().compute_baseline(dataset.adata.X)\n", "benchmark_results[\"pca_baseline\"] = run_clustering_benchmark(\n", "    baseline_embedding, dataset.adata.obs, dataset.labels\n", ")" ] }, { "cell_type": "markdown", "id": "f8b0e3d6", "metadata": {}, "source": [ "#### Iterative Model Evaluation\n", "\n", "Extend the workflow to evaluate clustering performance for fine-tuned or newly developed model variants."
] }, { "cell_type": "code", "execution_count": 10, "id": "653e0095", "metadata": {}, "outputs": [], "source": [ "# Train an scVI variant with the given latent dimension (a new model is instantiated on each call)\n", "def fine_tune_scvi_model(adata, n_latent=10, max_epochs=10):\n", "    model = scvi.model.SCVI(adata, n_latent=n_latent)\n", "    model.train(\n", "        max_epochs=max_epochs,\n", "        plan_kwargs={\"lr\": 0.0005},\n", "        early_stopping=True,\n", "        early_stopping_patience=10,\n", "    )\n", "    return model" ] }, { "cell_type": "markdown", "id": "14e7bf27-d67b-4cd8-af3a-7095766827a9", "metadata": {}, "source": [ "**Demonstrating Iterative Model Evaluation and Hyperparameter Tuning with czbenchmarks**\n", "\n", "In this section, we demonstrate how `czbenchmarks` supports iterative model evaluation and hyperparameter tuning within a development workflow. Using the `fine_tune_scvi_model` function defined above, we simulate the process of updating model parameters, such as varying the latent dimension (`n_latent`), and assess the impact on clustering performance.\n", "\n", "**Workflow Overview:**\n", "1. **Generate Model Variants:** Create multiple scVI model versions by adjusting key parameters (e.g., `n_latent`).\n", "2. **Fine-Tune Each Variant:** Train each model variant using the scVI training API with custom configurations.\n", "3. 
**Evaluate Performance:** Apply `czbenchmarks` to compute clustering metrics (ARI, NMI) for each variant, enabling direct comparison and informed parameter selection.\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "76cda22c-0354-4120-ac6c-21c712d502e0", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO: GPU available: True (mps), used: False\n", "INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (mps), used: False\n", "INFO: TPU available: False, using: 0 TPU cores\n", "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", "INFO: HPU available: False, using: 0 HPUs\n", "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7e886df2a13a45e0be1cfe933448a23c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: 0%| | 0/10 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
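<table><thead><tr><th></th><th>scvi_pretrained</th><th>pca_baseline</th><th>scvi_n_latent_3</th><th>scvi_n_latent_5</th><th>scvi_n_latent_10</th><th>scvi_n_latent_25</th></tr></thead><tbody>
<tr><th>MetricType.ADJUSTED_RAND_INDEX</th><td>0.71937</td><td>0.642149</td><td>0.506675</td><td>0.528759</td><td>0.518380</td><td>0.507529</td></tr>
<tr><th>MetricType.NORMALIZED_MUTUAL_INFO</th><td>0.86654</td><td>0.833138</td><td>0.734223</td><td>0.755720</td><td>0.732144</td><td>0.740127</td></tr></tbody></table>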
\n", "" ], "text/plain": [ " scvi_pretrained pca_baseline \\\n", "MetricType.ADJUSTED_RAND_INDEX 0.71937 0.642149 \n", "MetricType.NORMALIZED_MUTUAL_INFO 0.86654 0.833138 \n", "\n", " scvi_n_latent_3 scvi_n_latent_5 \\\n", "MetricType.ADJUSTED_RAND_INDEX 0.506675 0.528759 \n", "MetricType.NORMALIZED_MUTUAL_INFO 0.734223 0.755720 \n", "\n", " scvi_n_latent_10 scvi_n_latent_25 \n", "MetricType.ADJUSTED_RAND_INDEX 0.518380 0.507529 \n", "MetricType.NORMALIZED_MUTUAL_INFO 0.732144 0.740127 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Prepare and display results\n", "results_df = pd.DataFrame(benchmark_results)\n", "display(results_df)" ] }, { "cell_type": "markdown", "id": "7e36e4f2", "metadata": {}, "source": [ "#### Visualizing Benchmark Results Across Model Variants\n", "\n", "The bar plot below shows the Adjusted Rand Index (ARI) and Normalized Mutual Information (NMI) scores for each model variant (including the pre-trained scVI, fine-tuned variants, and PCA baseline). Higher scores indicate better clustering performance." ] }, { "cell_type": "code", "execution_count": 13, "id": "f421343f-a140-49b1-b228-6571d30a49fc", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot ARI and NMI for each model variant\n", "# Transpose so that model variants sit on the x-axis and metrics form the legend\n", "results_df.T.plot.bar(figsize=(10, 5))\n", "plt.title(\"Clustering Performance: scVI Model Variants vs. PCA Baseline\", fontsize=14)\n", "plt.ylabel(\"Score\", fontsize=10)\n", "plt.xlabel(\"Model Variant\", fontsize=10)\n", "plt.xticks(rotation=0)\n", "plt.legend(title=\"Metric\")\n", "plt.tight_layout()\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": ".venv_notebooks", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }