Training Octopi Models

This guide explains how to train 3D U-Net–based models with Octopi, covering both single-model training and automated model exploration.


Training Modes

Octopi supports two complementary workflows:

  • Single Model Training – Train or fine-tune a specific architecture when you already know what you want.

  • Model Exploration – Automatically search for strong architectures and hyperparameters using Bayesian optimization (recommended for new applications).

Dataset Splitting

For both single-model training and model exploration, Octopi supports training on data drawn from multiple CoPick projects by providing multiple --config files.

You may explicitly control which runs are used for training and validation by specifying --trainRunIDs and --validateRunIDs.

If neither of these options is provided, Octopi automatically:

  1. Collects all runs that contain both the requested tomograms and the specified segmentation targets
  2. Splits the data into training and validation sets according to the --data-split ratio
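The automatic split can be pictured with the minimal Python sketch below. It is an illustration, not Octopi's implementation. The helper names (eligible_runs, split_runs), the seeded shuffle, and the rounding rule are assumptions. Only the --data-split convention (one value for train/val, two values for train/val/test) comes from this page.

```python
import random
from typing import Dict, List, Sequence, Set, Tuple


def eligible_runs(available: Dict[str, Set[str]], required: Set[str]) -> List[str]:
    """Hypothetical step 1: keep runs whose data (tomograms + segmentations) covers every required item."""
    return sorted(run for run, items in available.items() if required <= items)


def split_runs(
    run_ids: Sequence[str], data_split: str, seed: int = 42
) -> Tuple[List[str], List[str], List[str]]:
    """Hypothetical step 2: split runs following the --data-split convention.

    "0.8"     -> 80% train, 20% validation, no test set
    "0.7,0.1" -> 70% train, 10% validation, remaining 20% test
    """
    fractions = [float(x) for x in data_split.split(",")]
    if len(fractions) not in (1, 2) or sum(fractions) > 1.0:
        raise ValueError(f"invalid --data-split value: {data_split!r}")

    runs = sorted(run_ids)              # fixed order so the seed alone determines the result
    random.Random(seed).shuffle(runs)   # seeded shuffle for reproducible splits
    n_train = round(len(runs) * fractions[0])
    if len(fractions) == 1:
        return runs[:n_train], runs[n_train:], []
    n_val = round(len(runs) * fractions[1])
    return runs[:n_train], runs[n_train:n_train + n_val], runs[n_train + n_val:]
```

With this rounding, ten eligible runs and --data-split 0.7,0.1 give 7 training, 1 validation, and 2 test runs.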

Single Model Training

For specific use cases, or when you already have a known good architecture, you can train a single model directly. The train command only supports U-Net models. To experiment with other model configurations or import new model designs, use the API or octopi model-explore.

Training New Models

This command initializes a new U-Net model using the architecture and training parameters you specify. Any option you omit falls back to the defaults listed in the tables below.

Bash
octopi train \
    --config config.json \
    --voxel-size 10 --tomo-alg wbp \
    --tomo-batch-size 50 --val-interval 10 \
    --target-info targets,octopi,1

Fine Tuning Models

To fine-tune existing base weights on a new dataset, use the same train command. Instead of specifying a model architecture, point --model-config and --model-weights at the configuration file and weights of the existing model:

Bash
octopi train \
    --config config.json \
    --voxel-size 10 --tomo-alg wbp \
    --model-config results/model_config.yaml \
    --model-weights results/best_model_weights.pth

octopi train parameters

| Parameter | Description | Example |
|---|---|---|
| --config | One or more CoPick configuration files. Multiple entries may be provided as session_name,path. | config.json |
| --voxel-size | Voxel size (Å) of tomograms used for training. Must match the target segmentations. | 10 |
| --target-info | Target specification in the form name or name,user_id,session_id. | targets,octopi,1 |
| --tomo-alg | Tomogram reconstruction algorithm(s). Multiple values may be comma-separated. | wbp or denoised,wbp |
| --trainRunIDs | Explicit list of run IDs to use for training (overrides automatic splitting). | run1,run2 |
| --validateRunIDs | Explicit list of run IDs to use for validation. | run3,run4 |
| --data-split | Train/validation(/test) split. Single value → train/val, two values → train/val/test. | 0.8 or 0.7,0.1 |
| --output | Directory where model checkpoints, logs, and configs are written. | results |

| Parameter | Description | Default |
|---|---|---|
| --num-epochs | Total number of training epochs. | 1000 |
| --val-interval | Frequency (in epochs) for computing validation metrics. | 10 |
| --batch-size | Number of cropped 3D patches processed per training step. | 16 |
| --lr | Learning rate for the optimizer. | 0.001 |
| --best-metric | Metric used to determine the best checkpoint. Supports fBetaN. | avg_f1 |
| --ncache-tomos | Number of tomograms cached per epoch (SmartCache window size). | 15 |
| --background-ratio | Foreground/background crop sampling ratio. | 0.0 |
| --tversky-alpha | Alpha parameter for the Tversky loss (foreground weighting). | 0.3 |
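Here is a rough sketch of how --background-ratio could translate into crop sampling. It assumes the value sets the number of background-centred crops per foreground-centred crop. That reading matches the model-explore parameter table, which describes 1.0 as a 50/50 split and values below 1.0 as biased toward foreground. Under this assumption the foreground probability is pos / (pos + neg) with pos = 1 and neg = background_ratio. That is the rule MONAI's RandCropByPosNegLabeld applies to its pos and neg arguments, but whether Octopi uses it this way is not confirmed here.

```python
from typing import Tuple


def foreground_crop_probability(background_ratio: float) -> float:
    """Probability a crop is centred on foreground, assuming background:foreground = ratio:1."""
    if background_ratio < 0:
        raise ValueError("background_ratio must be non-negative")
    return 1.0 / (1.0 + background_ratio)


def expected_crops(batch_size: int, background_ratio: float) -> Tuple[float, float]:
    """Expected (foreground, background) crop counts in one batch of patches."""
    p_fg = foreground_crop_probability(background_ratio)
    return batch_size * p_fg, batch_size * (1.0 - p_fg)


print(expected_crops(16, 0.0))  # (16.0, 0.0): the default samples only foreground-centred crops
print(expected_crops(16, 1.0))  # (8.0, 8.0): a 50/50 split
```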

Choosing --ncache-tomos

Use this parameter when your dataset has more tomograms than fit in memory at once. Only the cached subset is loaded per epoch, which keeps memory usage bounded. Values between 8 and 32 are recommended. Higher values expose the model to more diversity per epoch and improve training throughput, but require more RAM.
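To pick a value, it helps to estimate how much RAM one cached item costs. The sketch below assumes each cached item holds a float32 tomogram plus a uint8 label volume of the same shape. The dtypes and the example shape are illustrative assumptions, so substitute the real dimensions of your tomograms at the chosen --voxel-size.

```python
import math
from typing import Tuple


def volume_gb(shape: Tuple[int, int, int], bytes_per_voxel: int) -> float:
    """Size of one 3D volume in GB (1 GB = 1e9 bytes)."""
    return math.prod(shape) * bytes_per_voxel / 1e9


def cache_gb(ncache_tomos: int, shape: Tuple[int, int, int],
             tomo_bytes: int = 4, label_bytes: int = 1) -> float:
    """Approximate RAM held by the cache window: one tomogram + one label volume per cached item."""
    per_item = volume_gb(shape, tomo_bytes) + volume_gb(shape, label_bytes)
    return ncache_tomos * per_item


# Hypothetical tomogram of 300 x 640 x 640 voxels (z, y, x).
shape = (300, 640, 640)
print(f"one tomogram: {volume_gb(shape, 4):.2f} GB")     # 0.49 GB as float32
print(f"15 cached items: {cache_gb(15, shape):.1f} GB")  # 9.2 GB including uint8 labels
```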

| Parameter | Description | Default |
|---|---|---|
| --channels | Feature map sizes at each UNet level. | 32,64,96,96 |
| --strides | Downsampling strides between UNet levels. | 2,2,1 |
| --res-units | Number of residual units per UNet level. | 1 |
| --dim-in | Input patch size in voxels (cube). | 96 |
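The architecture options interact. Each stride downsamples the patch between two consecutive levels, so --strides needs one entry per transition (one fewer than --channels). --dim-in should also divide evenly by the product of the strides. The helper below is a hypothetical pre-flight check written for this guide. The len(strides) == len(channels) - 1 rule mirrors the length MONAI's UNet expects. Mapping these flags onto MONAI's channels, strides, and num_res_units arguments is an assumption based on their names.

```python
import math
from typing import Sequence


def bottleneck_size(channels: Sequence[int], strides: Sequence[int], dim_in: int) -> int:
    """Check a channels/strides/dim_in combination and return the patch edge at the lowest level."""
    if len(channels) < 2:
        raise ValueError("a UNet needs at least two levels")
    if len(strides) != len(channels) - 1:
        raise ValueError("expected one stride per level transition (len(channels) - 1 values)")
    total_stride = math.prod(strides)
    if dim_in % total_stride:
        raise ValueError(f"--dim-in {dim_in} is not divisible by the total stride {total_stride}")
    return dim_in // total_stride


# Defaults from the table: 96 -> 48 -> 24 -> 24 voxels per edge.
print(bottleneck_size([32, 64, 96, 96], [2, 2, 1], 96))  # 24
```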

These options are used to continue training from an existing model.

| Parameter | Description | Example |
|---|---|---|
| --model-config | Model configuration generated by a previous training run. | results/model_config.yaml |
| --model-weights | Pre-trained model weights used for initialization. | results/best_model_weights.pth |

Training Output

During training, you'll see:

  • Progress indicators: Real-time loss and accuracy metrics
  • Validation results: Periodic evaluation on validation set
  • Model checkpoints: Saved to results/ directory by default
  • Training logs: Detailed logs for monitoring and debugging

Model Exploration

Why Start with Model Exploration?

Rather than manually guessing which learning rates, batch sizes, or architectural choices work best for your tomograms, model exploration systematically tests combinations and learns from each trial to make better choices in the next. This automated approach typically finds better models than manual tuning.

[Figure: Bayesian optimization workflow]

Octopi's automated architecture search uses Bayesian optimization to efficiently explore hyperparameters and find well-performing configurations for your specific data.

Model exploration is recommended because:

  • ✅ No expertise required: automatically searches for a strong model for your data
  • ✅ Efficient search: each trial informs the next, so good configurations for your dataset are found in fewer trials
  • ✅ Time savings: avoids trial-and-error experimentation

Quick Start

Bash
octopi model-explore \
    --config config.json \
    --target-info targets,octopi,1 \
    --voxel-size 10 --tomo-alg denoised \
    --data-split 0.7 --model-type Unet \
    --num-trials 100 --best-metric fBeta3 \
    --study-name my-explore-job

This runs 100 optimization trials (the --num-trials default, set explicitly here) and automatically saves results to a timestamped directory.

octopi model-explore parameters

| Parameter | Description | Default | Notes |
|---|---|---|---|
| --config | One or more CoPick config paths. Multiple entries may be provided as session_name,path. | – | Use multiple --config entries to combine sessions |
| --voxel-size | Voxel size (Å) of tomograms used. | 10 | Must match target segmentations |
| --target-info | Target specification: name or name,user_id,session_id. | targets,octopi,1 | From the label preparation step |
| --tomo-alg | Tomogram reconstruction algorithm(s). Comma-separated values enable multi-alg training. | wbp | Example: denoised,wbp |
| --trainRunIDs | Explicit list of run IDs to use for training (overrides automatic splitting). | – | Example: run1,run2 |
| --validateRunIDs | Explicit list of run IDs to use for validation. | – | Example: run3,run4 |
| --data-split | Train/val(/test) split. Single value → train/val, two values → train/val/test. | 0.8 | Example: 0.7,0.1 → 70/10/20 |
| --output | Name/path of the output directory. | explore_results | Results are written here per study |
| --study-name | Name of the Optuna/MLflow experiment. | model-search | Useful for organizing runs |

| Parameter | Description | Default | Notes |
|---|---|---|---|
| --model-type | Model family used for exploration. | Unet | Options: unet, attentionunet, mednext, segresnet |
| --num-epochs | Number of epochs per trial. | 1000 | Consider fewer epochs for quick sweeps |
| --val-interval | Validation frequency (every N epochs). | 10 | Smaller = more frequent metrics |
| --ncache-tomos | Number of tomograms cached per epoch (SmartCache window size). | 15 | Higher values improve throughput but require more memory |
| --best-metric | Metric used to select the best checkpoint (supports fBetaN). | avg_f1 | Example: fBeta3 emphasizes recall |
| --background-ratio | Foreground/background crop sampling ratio. | 0.0 | 1.0 → 50/50; <1.0 biases toward foreground |
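The fBetaN metrics follow the standard F-beta definition, in which recall is weighted beta times as heavily as precision. The sketch below computes the score both from precision/recall and from raw counts. The function names are illustrative. How Octopi matches predicted picks to ground truth to obtain those counts is not covered here.

```python
def fbeta(precision: float, recall: float, beta: float = 1.0) -> float:
    """F-beta: weighted harmonic mean of precision and recall; beta > 1 favours recall."""
    b2 = beta ** 2
    denom = b2 * precision + recall
    return (1 + b2) * precision * recall / denom if denom else 0.0


def fbeta_from_counts(tp: int, fp: int, fn: int, beta: float = 1.0) -> float:
    """Same score from true positives, false positives, and false negatives."""
    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


# With beta = 3 (fBeta3), missed particles cost far more than extra false picks:
print(round(fbeta(precision=0.9, recall=0.5, beta=3), 3))  # 0.523
print(round(fbeta(precision=0.5, recall=0.9, beta=3), 3))  # 0.833
```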

| Parameter | Description | Default | Notes |
|---|---|---|---|
| --num-trials | Number of Optuna trials (models) to evaluate. | 100 | Use 50–200 for sufficient exploration of the parameter landscape |
| --random-seed | Random seed for reproducibility. | 42 | Fix this when comparing changes |

| Parameter | Description | Default | Notes |
|---|---|---|---|
| --submitit | Submit trials as independent SLURM jobs using submitit instead of running locally. | False | Enables HPC / multi-node execution |
| --njobs | Maximum number of concurrent SLURM jobs (trials) to run at once. | 5 | Each job runs exactly one Optuna trial |
| --compute-constraint | CPU and memory request per SLURM job in the form cpus,mem_gb. | 4,16 | Example: 8,32 requests 8 CPUs and 32 GB RAM |
| --timeout | Walltime limit (hours) per SLURM job. | 4 | Jobs exceeding this limit are terminated by the scheduler |
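To check what a setting requests, it helps to translate the flags into the resource arguments of submitit's AutoExecutor.update_parameters (cpus_per_task, mem_gb, timeout_min). The translation below is an illustrative assumption, not Octopi's code. In particular, it does not model how --njobs throttles concurrent jobs.

```python
from typing import Dict


def submitit_resources(compute_constraint: str = "4,16", timeout_hours: float = 4) -> Dict[str, int]:
    """Turn '--compute-constraint cpus,mem_gb' and '--timeout' (hours) into submitit-style kwargs."""
    try:
        cpus_text, mem_text = compute_constraint.split(",")
        cpus, mem_gb = int(cpus_text), int(mem_text)
    except ValueError as exc:
        raise ValueError(f"expected 'cpus,mem_gb', got {compute_constraint!r}") from exc
    if cpus < 1 or mem_gb < 1:
        raise ValueError("CPU and memory requests must be positive")
    return {"cpus_per_task": cpus, "mem_gb": mem_gb, "timeout_min": int(timeout_hours * 60)}


# Usage with submitit (not run here):
#   executor = submitit.AutoExecutor(folder="slurm_logs")
#   executor.update_parameters(**submitit_resources("8,32", timeout_hours=4))
print(submitit_resources("8,32", timeout_hours=4))
# {'cpus_per_task': 8, 'mem_gb': 32, 'timeout_min': 240}
```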

What changes when --submitit is enabled?

By default, octopi model-explore runs locally, launching one worker per available GPU and executing multiple trials within a single process.

When --submitit is enabled, Octopi switches to a job-based execution model designed for SLURM-managed HPC clusters:

  • Each Optuna trial is executed as a separate SLURM job
  • Jobs may run on different nodes and start at different times
  • Resource limits (CPUs, memory, walltime) are enforced per trial
  • Failed jobs affect only the corresponding trial, not the entire study

This mode provides better scalability and fault isolation for large model exploration runs, especially on shared HPC systems.

Because multiple jobs may write to the same Optuna and MLflow tracking databases, Octopi uses best-effort retries for database operations to safely handle temporary contention.
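A common pattern for best-effort database retries is exponential backoff with jitter around each write. The sketch below is a generic illustration built on the standard-library sqlite3 error type. Octopi's actual retry policy (attempt counts, delays, and which errors count as transient) is not documented here and may differ.

```python
import random
import sqlite3
import time
from typing import Callable, TypeVar

T = TypeVar("T")

# SQLite messages that indicate temporary contention rather than a real failure.
TRANSIENT_MARKERS = ("database is locked", "database table is locked")


def with_db_retries(op: Callable[[], T], max_attempts: int = 5, base_delay: float = 0.5) -> T:
    """Run op(); on transient SQLite contention, retry with exponential backoff plus jitter."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return op()
        except sqlite3.OperationalError as exc:
            transient = any(marker in str(exc) for marker in TRANSIENT_MARKERS)
            if not transient or attempt == max_attempts:
                raise  # permanent error, or out of retries
            delay = base_delay * 2 ** (attempt - 1)       # 0.5 s, 1 s, 2 s, ...
            time.sleep(delay + random.uniform(0, delay))  # jitter spreads out competing jobs
    raise AssertionError("unreachable")
```

Each database write would be wrapped in a zero-argument callable, for example with_db_retries(lambda: conn.execute(sql, params)).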

What Gets Optimized?

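Model exploration searches within a single architecture family selected with --model-type (the parameter table lists unet, attentionunet, mednext, and segresnet). The two U-Net variants are:

  • Unet - Standard 3D U-Net (default, recommended for most cases)
  • AttentionUnet - U-Net with attention mechanisms (for complex data)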

For each architecture, it optimizes:

  • Hyperparameters - Learning rate, batch size, loss function parameters
  • Architecture details - Channel sizes, stride configurations, residual units
  • Training strategies - Regularization and data augmentation

Parallelism and GPU Utilization

When running octopi model-explore, Octopi automatically detects the available GPU resources and spawns one worker per GPU. Each worker independently trains a candidate model, allowing multiple Optuna trials to run concurrently.

This means:

  • On a machine with N GPUs, up to N models are trained in parallel
  • Trial scheduling is handled automatically by Optuna
  • GPU utilization scales naturally from a single workstation to multi-GPU HPC nodes

Without --submitit, the number of concurrent trials on a shared HPC system is therefore determined by the number of GPUs allocated to your job (e.g., via SLURM or another scheduler).

Note

If fewer GPUs are available than the total number of trials (--num-trials), remaining trials are queued and executed as workers become free.
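This queueing behaviour is greedy list scheduling: each waiting trial starts on whichever worker frees up first. The short simulation below, written for this guide, estimates total wall-clock time for a list of trial durations. The durations are made-up inputs. Real trial times vary with the sampled architecture and hyperparameters.

```python
import heapq
import math
from typing import Sequence


def makespan(durations: Sequence[float], n_workers: int) -> float:
    """Finish time when queued trials start on whichever worker (GPU) becomes free first."""
    if n_workers < 1:
        raise ValueError("need at least one worker")
    free_at = [0.0] * n_workers          # time at which each worker next becomes idle
    heapq.heapify(free_at)
    for duration in durations:           # trials start in submission order
        start = heapq.heappop(free_at)   # earliest-available worker
        heapq.heappush(free_at, start + duration)
    return max(free_at)


def waves(num_trials: int, n_workers: int) -> int:
    """Rounds of trials needed if every trial takes the same time."""
    return math.ceil(num_trials / n_workers)


print(waves(100, 4))             # 25 rounds on a 4-GPU node
print(makespan([1.0] * 100, 4))  # 25.0 time units
```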

Monitoring Your Training

Track the progress of model exploration runs in real time, inspect trial performance, and understand which hyperparameters and architectural choices drive model quality.

Octopi integrates with Optuna to provide a high-level view of the architecture and hyperparameter search process. This dashboard is best suited for understanding which trials perform best, which parameters matter most, and how the optimization converges over time.

Setup Options:

  • Web Dashboard – Launch the Optuna dashboard directly from the command line:

    Bash
    optuna-dashboard sqlite:///{path-to}/trials.db
    
    Replace {path-to} with the model exploration output directory, which contains the Optuna study database (trials.db).

  • VS Code Extension - Install the Optuna Dashboard extension for integrated monitoring, then right-click trials.db in the file explorer to launch the dashboard.

What you'll see in the Optuna dashboard:

  • Trial progress and current best performance
  • Parameter importance (which settings matter most)
  • Optimization history and convergence trends

MLflow complements Optuna by providing detailed, per-trial training information, including loss curves, validation metrics, model checkpoints, and configuration artifacts.

If you are running octopi model-explore on your local machine, start the MLflow UI in the same environment:

Bash
mlflow ui

If the mlflow command is not directly available, use the alternative:

Bash
python -m mlflow ui

This will print output similar to:

Text Only
INFO:     Uvicorn running on http://127.0.0.1:5000

Copy the URL and paste it into your browser to open the MLflow dashboard.


When octopi is executed on a remote HPC system, the MLflow UI runs on a machine that is not directly accessible from your local browser. To view it locally, you must forward the MLflow port using an SSH tunnel.

On your local machine, open a terminal and run:

Bash
ssh -L 5000:localhost:5000 username@login-node-hostname

This logs in to the HPC login node and forwards port 5000 on your local machine to port 5000 on the login node.

Once logged in, navigate to the directory where the model exploration results are being written. You should see files such as:

Text Only
mlflow.db
mlruns/

These files define the MLflow experiment state. From this directory, start the MLflow UI:

Bash
python -m mlflow ui --host 0.0.0.0 --port 5000

Finally, on your local machine, open the following address in your browser:

Text Only
http://127.0.0.1:5000

When the dashboard opens, you should see the MLflow experiment named after the --study-name you specified when running the job.

[Figure: MLflow dashboard]

Model Exploration Output

  • Optuna study database
  • MLflow experiment logs
  • Best-performing model checkpoints
  • Per-trial metrics and configurations

Class Weighting

When training on datasets with multiple particle types, some classes may be more important to detect than others, and some should be excluded from scoring entirely. Octopi supports per-class weights read directly from the CoPick configuration file, so no additional training flags are needed.

Weights are applied in two places:

  • Training: the validation F-beta score is computed as a weighted average across classes, so checkpoint selection favors performance on higher-weight classes.
  • Evaluation: octopi evaluate uses the same weights when reporting aggregate metrics.

How to set weights in your CoPick config

Add a metadata field with a weight key to any entry in the pickable_objects list of your config.json:

JSON
{
    "name": "ribosome",
    "is_particle": true,
    "label": 4,
    "radius": 150,
    "metadata": {
        "weight": 2
    }
}

If metadata or weight is absent for a class, it defaults to 1. Set weight to 0 to exclude a class from scoring entirely.

The following config weights beta-galactosidase and thyroglobulin twice as heavily as apoferritin and ribosome, and excludes beta-amylase from scoring:

JSON
{
    "pickable_objects": [
        {
            "name": "apoferritin",
            "is_particle": true,
            "label": 1,
            "radius": 60,
            "metadata": { "weight": 1 }
        },
        {
            "name": "beta-amylase",
            "is_particle": true,
            "label": 2,
            "radius": 80,
            "metadata": { "weight": 0 }
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "label": 3,
            "radius": 90,
            "metadata": { "weight": 2 }
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "label": 4,
            "radius": 150,
            "metadata": { "weight": 1 }
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "label": 5,
            "radius": 130,
            "metadata": { "weight": 2 }
        }
    ]
}
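Below is a minimal sketch of how these weights could be read and applied. It assumes the weighted average is an arithmetic mean normalized by the sum of weights; this page says "weighted average" but does not give the exact formula. The helper names are hypothetical, and in practice the config dict would come from json.load on your config.json.

```python
from typing import Dict


def class_weights(config: dict) -> Dict[str, float]:
    """Per-class weights from a CoPick config dict; a missing metadata or weight means 1."""
    return {
        obj["name"]: float((obj.get("metadata") or {}).get("weight", 1))
        for obj in config["pickable_objects"]
    }


def weighted_score(per_class: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted mean of per-class scores; weight-0 classes contribute to neither sum."""
    total = sum(weights.get(name, 1.0) for name in per_class)
    if total == 0:
        raise ValueError("every scored class has weight 0")
    return sum(weights.get(name, 1.0) * score for name, score in per_class.items()) / total


# Weights from the example config, paired with made-up per-class F-beta scores.
weights = {"apoferritin": 1, "beta-amylase": 0, "beta-galactosidase": 2,
           "ribosome": 1, "thyroglobulin": 2}
scores = {"apoferritin": 0.9, "beta-amylase": 0.1, "beta-galactosidase": 0.5,
          "ribosome": 0.8, "thyroglobulin": 0.6}
print(round(weighted_score(scores, weights), 3))  # 0.65 = (0.9 + 2*0.5 + 0.8 + 2*0.6) / 6
```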

Tip

Weights only affect the validation metric used for checkpoint selection and the aggregate evaluation score; the loss function and per-class metrics are unaffected.


Compute & Performance

This section covers practical recommendations for allocating compute, diagnosing bottlenecks, and recovering from out-of-memory (OOM) errors. If you're just getting started, the defaults work on most hardware; come back here if you need to tune throughput or hit an OOM.

Requesting resources

Octopi auto-scales at runtime: it reads os.sched_getaffinity(0) to match SLURM's --cpus-per-task, detects GPU VRAM and compute capability for mixed precision, and picks between CacheDataset and SmartCacheDataset based on --ncache-tomos. The main lever is what you ask SLURM for. On a SLURM cluster with A100, H100, or H200 GPUs, a typical request is:

Bash
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=12G
#SBATCH --gpus=1
#SBATCH --constraint=[h100|h200|a100]

Sixteen CPUs saturate the DataLoader worker pool, which is capped at 16 workers. At 12 GB per CPU, the 192 GB total leaves comfortable headroom for worker copy-on-write drift on large caches. Drop to 8G per CPU if your cluster is tight and you cache fewer than ~15 tomograms.
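You can check this CPU and memory arithmetic from inside a job. The sketch assumes Linux, since os.sched_getaffinity is not available on macOS or Windows; elsewhere it falls back to os.cpu_count. It also uses a hypothetical worker rule of one DataLoader worker per usable CPU, capped at 16. Octopi's own auto_num_workers may reserve CPUs differently.

```python
import os


def usable_cpus() -> int:
    """CPUs this process may run on; on Linux under SLURM this reflects --cpus-per-task."""
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


def worker_count(cap: int = 16) -> int:
    """Hypothetical rule: one DataLoader worker per usable CPU, never more than `cap`."""
    return max(1, min(usable_cpus(), cap))


def requested_ram_gb(cpus_per_task: int, mem_per_cpu_gb: float) -> float:
    """Total job memory implied by --cpus-per-task and --mem-per-cpu."""
    return cpus_per_task * mem_per_cpu_gb


print(requested_ram_gb(16, 12))  # 192 (GB), as in the script above
print(requested_ram_gb(16, 8))   # 128 (GB) after dropping to 8G per CPU
print(f"{usable_cpus()} usable CPUs -> {worker_count()} workers")
```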

Local workstation with RTX 3090/4090, A5000, etc.:

  • 8–12 CPUs is enough worker throughput for most datasets.
  • 64–96 GB system RAM: budget ~1–2 GB per worker plus ~0.5–2 GB per cached tomogram.

Limited hardware (8–16 GB VRAM, 16–32 GB system RAM):

  • Lower --ncache-tomos so the training cache fits in RAM.
  • Lower --batch-size from 16 to 8 to fit the GPU.
  • Drop --dim-in from 96 to 64 for another ~3.4× reduction in activation memory.

| GPU tier | Examples | --batch-size | Mixed precision (auto) | System RAM |
|---|---|---|---|---|
| ≤ 12 GB | RTX 4070, 3060 12 GB, T4 | 8 | bf16 (RTX 4070, 3060); fp16 + GradScaler (T4) | 16–32 GB |
| 16 GB | RTX A4000, 4060 Ti 16 GB, V100, L4 | 16 | bf16; fp16 + GradScaler on V100 | 32–64 GB |
| 24 GB | RTX 4090, 3090, A5000 | 16–32 | bf16 | 64–128 GB |
| 40–48 GB | A100 40 GB, A6000, L40 | 32–48 | bf16 | 128 GB |
| ≥ 80 GB | A100 80 GB, H100, H200 | 48–64 | bf16 | 128–192 GB |

Octopi selects mixed precision automatically:

  • bf16 on Ampere and newer (RTX 30xx/40xx, A100, H100, H200); no gradient scaler is needed.
  • fp16 + GradScaler on Volta/Turing (V100, T4, RTX 20xx).
  • fp32 on Pascal and older, or on CPU.
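The selection rule above maps onto CUDA compute capability: Volta is 7.0, Turing 7.5, Ampere 8.x, Ada 8.9, and Hopper 9.0. The function below restates the rule as code for illustration. In PyTorch, the (major, minor) tuple comes from torch.cuda.get_device_capability(). Whether Octopi keys off capability directly or off a helper such as torch.cuda.is_bf16_supported() is an assumption.

```python
from typing import Optional, Tuple


def pick_precision(capability: Optional[Tuple[int, int]]) -> str:
    """Mixed-precision mode for a CUDA compute capability (major, minor); None means CPU."""
    if capability is None:
        return "fp32"
    major, _minor = capability
    if major >= 8:        # Ampere, Ada, Hopper and newer: bf16, no gradient scaler
        return "bf16"
    if major == 7:        # Volta (7.0) and Turing (7.5): fp16 with a GradScaler
        return "fp16+GradScaler"
    return "fp32"         # Pascal (6.x) and older


for name, cap in [("V100", (7, 0)), ("T4", (7, 5)), ("A100", (8, 0)),
                  ("RTX 4090", (8, 9)), ("H100", (9, 0)), ("CPU", None)]:
    print(f"{name:8s} -> {pick_precision(cap)}")
```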

Where memory goes

System RAM breakdown

| Component | Rough cost | Notes |
|---|---|---|
| Main process (Python + PyTorch + CUDA) | ~3 GB | Fixed |
| Cached training tomograms | ~500 MB – 2 GB each | Depends on volume size and dtype |
| Cached validation tomograms | ~500 MB – 2 GB each | Usually 2–6 val tomograms |
| Each DataLoader training worker | ~1–2 GB overhead | × num_workers |
| Worker copy-on-write drift | Up to the cache size per worker | Accumulates during an epoch; resets when the worker exits |
| Validation prefetch | ~1–2 GB × val_workers × 2 | Only during validation |
| Matplotlib training-curve figure | ~0.5–2 GB | Grows slowly across validations |

Worst-case RAM peak is usually right after the first validation, when validation-only allocations are still resident and training workers re-fork for the next epoch. Size the cgroup for this peak.
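Adding up the rows gives a rough budget for that post-validation peak. In the sketch below, every default is an assumed midpoint of the ranges in the table (for example 1.5 GB per worker). val_workers = 2 is a placeholder, not an Octopi default. Copy-on-write drift stays at zero unless you supply an estimate, because the table gives only its upper bound.

```python
def peak_ram_gb(
    n_train_cached: int,
    n_val_cached: int,
    gb_per_tomo: float,
    n_workers: int,
    gb_per_worker: float = 1.5,        # midpoint of ~1-2 GB per training worker
    val_workers: int = 2,              # placeholder value
    gb_prefetch_per_val_worker: float = 1.5,
    gb_main: float = 3.0,              # Python + PyTorch + CUDA in the main process
    gb_figure: float = 1.0,            # matplotlib training-curve figure
    gb_drift_per_worker: float = 0.0,  # copy-on-write drift estimate (up to the cache size)
) -> float:
    """Rough system-RAM peak right after the first validation, summing the table rows."""
    caches = (n_train_cached + n_val_cached) * gb_per_tomo
    workers = n_workers * (gb_per_worker + gb_drift_per_worker)
    validation = val_workers * gb_prefetch_per_val_worker * 2 + gb_figure
    return gb_main + caches + workers + validation


# 15 training + 4 validation tomograms at ~1 GB each, 16 workers:
print(peak_ram_gb(15, 4, 1.0, 16))                         # 53.0
print(peak_ram_gb(15, 4, 1.0, 16, gb_drift_per_worker=1))  # 69.0
```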

GPU VRAM breakdown (default UNet at 96³)

| Component | Rough cost |
|---|---|
| CUDA context | ~0.5 GB |
| Model weights | 0.05–0.2 GB |
| Optimizer state (AdamW, fp32) | 2× model weights |
| Activations (bf16, batch 16) | ~3–5 GB |
| Activations (bf16, batch 32) | ~8–10 GB |
| Validation forward (64 patches at 128³) | ~2 GB transient |
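The activation rows scale predictably. Activation memory grows roughly with patch volume (the cube of --dim-in) and linearly with batch size, which is where the ~3.4× figure for 96 → 64 comes from. The one-line helper below, written for this guide, makes the arithmetic explicit. It ignores the fixed costs in the table (CUDA context, weights, optimizer state).

```python
def activation_reduction(dim_old: int, dim_new: int,
                         batch_old: int = 16, batch_new: int = 16) -> float:
    """Factor by which activation memory shrinks: cubic in patch edge, linear in batch size."""
    return (dim_old / dim_new) ** 3 * (batch_old / batch_new)


print(activation_reduction(96, 64))         # 3.375 (the ~3.4x quoted for --dim-in)
print(activation_reduction(96, 96, 16, 8))  # 2.0 (halving --batch-size)
print(activation_reduction(96, 64, 16, 8))  # 6.75 (both together)
```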

Diagnosing CPU-bound vs GPU-bound

Run nvidia-smi dmon -s u -d 2 -c 20 in a second terminal during a training epoch. Skip the first epoch: cudnn.benchmark autotuning and cache initialization make it unrepresentative.

| SM utilization | Interpretation | What to try |
|---|---|---|
| >85% sustained | GPU-bound: the GPU is the limit | Larger --batch-size; future torch.compile or channels_last_3d support |
| 40–70% with regular dips | Partially data-bound: workers can't keep the GPU fed | More CPUs, higher --ncache-tomos |
| <30% | Severely data-bound | More CPUs, a bigger cache, or a better-matched CPU/GPU budget |
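If you capture the dmon output to a file, a small parser can average the sm column and apply the bands from the table. This is an illustrative sketch. It locates the sm column from the '#' header line rather than assuming a fixed position, because column layouts can differ between driver versions. It labels readings that fall between the table's bands as borderline. The sample text is made up.

```python
from statistics import mean
from typing import List


def sm_values(dmon_text: str) -> List[int]:
    """Collect the 'sm' column from `nvidia-smi dmon -s u` output (all GPUs pooled)."""
    sm_col = None
    values: List[int] = []
    for line in dmon_text.splitlines():
        stripped = line.strip()
        if not stripped:
            continue
        if stripped.startswith("#"):                 # header lines: find the sm column once
            header = stripped.lstrip("#").split()
            if sm_col is None and "sm" in header:
                sm_col = header.index("sm")
            continue
        fields = stripped.split()
        if sm_col is not None and sm_col < len(fields) and fields[sm_col].isdigit():
            values.append(int(fields[sm_col]))      # skips '-' placeholders
    return values


def classify(values: List[int]) -> str:
    """Apply the utilization bands from the table to the mean sm reading."""
    if not values:
        raise ValueError("no sm readings found")
    avg = mean(values)
    if avg > 85:
        return "GPU-bound"
    if avg < 30:
        return "severely data-bound"
    if 40 <= avg <= 70:
        return "partially data-bound"
    return "borderline (between the table's bands)"


sample = """# gpu    sm   mem   enc   dec
# Idx     %     %     %     %
    0    55    30     0     0
    0    62    33     0     0
    0    48    29     0     0
"""
print(classify(sm_values(sample)))  # partially data-bound (mean 55%)
```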

Same epoch time on different GPUs?

If training takes the same wall-clock time on an A6000 and an H100, you're definitely data-bound: the GPU never gets a chance to differentiate itself. This is a clean diagnostic.

Troubleshooting OOMs

System RAM OOMs. Symptoms: Detected N oom_kill events in StepId=... from slurmstepd, or DataLoader worker (pid N) killed by signal: Killed after a few epochs.

In order of preference:

  1. Bump --mem-per-cpu: for example, going from 8G to 12G on a 16-CPU allocation adds 64 GB. This is the cheapest fix.
  2. Lower --ncache-tomos: each cached tomogram is shared copy-on-write with every worker, so cache size × worker count is the worst-case drift.
  3. Lower val_batch_size: the CopickDataModule.create default is 64; drop to 32 if validation OOMs.
  4. Lower the training-worker cap: in octopi/datasets/helpers.py, adjust auto_num_workers(cap=...) from 16 to 12, then 8. This costs ~15% of epoch throughput but significantly reduces peak RAM.
  5. Enable glibc trim in your SLURM script:
    Bash
    export MALLOC_TRIM_THRESHOLD_=0
    export MALLOC_ARENA_MAX=2
    
    Forces the C allocator to return freed memory to the OS more aggressively.

GPU (CUDA) OOMs. Symptoms: CUDA out of memory. Tried to allocate ....

In order of preference:

  1. Lower --batch-size (crops per tomogram), the biggest GPU VRAM lever: drop 32 → 16 → 8.
  2. Lower --dim-in (crop size): going from 96 to 64 cuts activation memory ~3.4×.
  3. Lower val_batch_size, stepping down through 128 → 64 → 32 (the CopickDataModule.create default is 64).
  4. Disable EMA (use_ema=False in the API), which saves one model-sized shadow copy.
  5. Shrink the model β€” e.g., --channels 32,64,80,80 instead of larger variants.

OOMs that first appear around val_interval

The first validation adds a persistent ~5–10 GB baseline (MONAI transform caches, the matplotlib figure, a larger CUDA pool) that is not released. If training runs fine through the early epochs and then crashes shortly after the first validation at epoch val_interval, you're pressing against the cgroup limit. Either bump --mem-per-cpu or reduce val_batch_size.

Why memory drifts over epochs

A short explanation of what you're fighting when you hit RAM OOMs after running fine for many epochs:

  • Copy-on-write drift. When PyTorch forks DataLoader workers, each shares the main-process memory via copy-on-write. Reads should be "free", but in Python even reading an object increments its reference count, which writes to memory. Workers therefore slowly accumulate private copies of the cached tomogram pages they touch, and over one epoch each worker can drift by up to the training cache size. Because workers exit at the end of each epoch, this memory is reclaimed (Octopi deliberately uses persistent_workers=False for this reason).
  • Python and glibc keep freed memory. Freed allocations go to allocator pools, not back to the kernel. From the cgroup's perspective, RSS is roughly monotonically non-decreasing until peak usage.
  • Validation adds a sticky baseline. First-time validation expands the CUDA caching allocator, imports code paths for the first time, creates a long-lived matplotlib figure, and builds class-index caches. None of this is released.

Together, these mean main-process RAM grows over the first few epochs and then plateaus. If that plateau plus per-epoch worker drift reaches the cgroup limit, you OOM. The fix is more RAM, less drift (fewer workers, a smaller cache), or a lower baseline (smaller val_batch_size).
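The reference-count effect behind copy-on-write drift is easy to observe in CPython. Merely storing another reference to an object updates the count kept in that object's header, which is a write to the memory page holding the header. The snippet uses a small bytearray as a stand-in for cached data. It illustrates the mechanism only and does not measure drift.

```python
import sys

cached = bytearray(1024)             # stand-in for an object held in the parent's cache
before = sys.getrefcount(cached)
batch = [cached]                     # a "read-only" use: just keep a reference to it
after = sys.getrefcount(cached)
print(after - before)                # 1 in CPython: the header's refcount field was written
```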


Next Steps

After training is complete:

Run inference - Apply your best model to new tomograms and get particle locations from predictions.