Training¶
This page covers data preparation and training deep learning models for 3D particle picking using octopi.
Target Creation¶
Before training, you need to create training targets from existing particle annotations. We can explicitly define all the objects we'd like to query with a list of tuples (object_name, user_id, session_id) identifying the source annotations. A query can target either point coordinates for proteins, or continuous segmentations such as membranes or organelles generated by software such as membrain-seg or saber.
The target creation process uses the following key parameters:
- pick_targets: List of tuples (object_name, user_id, session_id) defining source point annotations
- seg_targets: List of segmentation targets (same format as pick_targets) for continuous structures
- radius_scale: Scale factor for creating spherical targets relative to object radius defined in config
- run_ids: Optional subset of tomograms (None for all available)
💡 Seeing Available Queries Per Copick Project
Remember, you can list all the available runs and annotation queries in any copick project with the octopi CLI, for both local projects and training data from the data-portal.
Creating Training Targets¶
from octopi.entry_points.run_create_targets import create_sub_train_targets
# Configuration
config = '../config.json'
target_name = 'targets'
target_user_id = 'octopi'
target_session_id = '1'
# Tomogram parameters
voxel_size = 10.012
tomogram_algorithm = 'wbp-denoised-denoiset-ctfdeconv'
radius_scale = 0.7 # Fraction of particle radius for sphere targets
# Define source annotations
pick_targets = [
('ribosome', 'data-portal', '1'),
('virus-like-particle', 'data-portal', '1'),
('apoferritin', 'data-portal', '1')
]
seg_targets = [('membrane', 'membrain-seg', '2')] # Optional segmentation targets, same tuple format as pick_targets
# Create targets
create_sub_train_targets(
config, pick_targets, seg_targets, voxel_size, radius_scale,
tomogram_algorithm, target_name, target_user_id, target_session_id
)
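Before moving on to training, it can be worth a quick sanity check that the target segmentations were actually written. Below is a minimal sketch using the copick Python API; the picks attributes and get_segmentations filter arguments are assumptions based on copick's data model:
import copick
# Open the project and inspect what each run now contains
root = copick.from_file(config)
for run in root.runs:
    # Source point annotations available for this run
    for picks in run.picks:
        print(run.name, 'picks:', picks.pickable_object_name,
              picks.user_id, picks.session_id)
    # Target segmentations written by create_sub_train_targets above
    segs = run.get_segmentations(name=target_name, user_id=target_user_id,
                                 session_id=target_session_id)
    print(run.name, '->', len(segs), 'target segmentation(s)')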
Model Training¶
Once training targets are created, you can train deep learning models for particle segmentation. The training process requires defining data splits, model architecture, loss functions, and training parameters. Key training parameters include:
- target_info: Tuple specifying which training targets to use (name, user_id, session_id)
- trainRunIDs/validateRunIDs: Optional lists defining which tomograms to use for training vs validation. If None, octopi automatically uses all runs that have the desired segmentation target.
- model_config: Dictionary specifying the neural network architecture and its parameters. A few default models are available to choose from, and new model architectures can also be imported.
- loss_function: Loss function appropriate for segmentation tasks (handles class imbalance)
- use_ema: Whether to use exponential moving average for more stable training
Basic Training Setup¶
from octopi.losses import FocalTverskyLoss # octopi's custom loss, an alternative to TverskyLoss
from monai.losses import TverskyLoss
from octopi.workflows import train
# Training configuration
config = 'train_config.json'
target_info = ('targets', 'octopi', '1') # (name, user_id, session_id)
results_folder = 'model_output'
# Data splits
trainRunIDs = ['run_001', 'run_002', 'run_003']
valRunIDs = ['run_004']
# Model architecture
model_config = {
'architecture': 'Unet',
'num_classes': 6, # number of objects + 1 for background
'dim_in': 80,
'strides': [2, 2, 1],
'channels': [48, 64, 80, 80],
'dropout': 0.0,
'num_res_units': 1,
}
# Loss function
loss_function = TverskyLoss(
include_background=True,
to_onehot_y=True,
softmax=True,
alpha=0.3,
beta=0.7
)
# Train the model
train(
    config, target_info, 'denoised', 10.012, loss_function,
    model_config, trainRunIDs=trainRunIDs, validateRunIDs=valRunIDs,
    model_save_path=results_folder
)
💡 Training Function Reference
train(config, target_info, tomo_algorithm, voxel_size, loss_function,
      model_config, model_weights=None, trainRunIDs=None, validateRunIDs=None,
      model_save_path='results', best_metric='fBeta2', num_epochs=1000, use_ema=True)
Trains a segmentation model for particle detection.
Parameters:
- config (str): Path to Copick configuration file
- target_info (tuple): Target specification (name, user_id, session_id)
- tomo_algorithm (str): Tomogram algorithm identifier
- voxel_size (float): Voxel spacing in Angstroms
- loss_function: PyTorch loss function
- model_config (dict): Model architecture configuration
- model_weights (str): Path to pretrained weights (default: None)
- trainRunIDs (list): List of training tomogram IDs (default: None - uses all available)
- validateRunIDs (list): List of validation tomogram IDs (default: None - uses all available)
- model_save_path (str): Directory to save trained model (default: 'results')
- best_metric (str): Metric for model selection (default: 'fBeta2')
- num_epochs (int): Maximum training epochs (default: 1000)
- use_ema (bool): Use exponential moving average (default: True)
Returns:
Training results dictionary with loss and metric histories.
Outputs:
- {model_save_path}/best_model.pth: Best model weights
- {model_save_path}/model_config.yaml: Model configuration
- {model_save_path}/results.json: Training metrics and results
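The saved outputs use standard formats, so they can also be reloaded outside of octopi; for example (paths assume the default model_save_path of 'results'):
import torch
import yaml
# Reload the best checkpoint and its matching architecture description
checkpoint = torch.load('results/best_model.pth', map_location='cpu')
with open('results/model_config.yaml') as f:
    model_config = yaml.safe_load(f)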
Cross Validation¶
Cross-validation is essential for robust model evaluation, especially when working with limited tomography data. It helps assess how well your model generalizes to unseen tomograms and prevents overfitting to specific data characteristics. This approach systematically trains multiple models, each time holding out a different tomogram for validation.
# Cross-validation training: hold out one tomogram per split
import copick

def get_all_run_ids():
    # Gather every run name from the copick project
    root = copick.from_file(config)
    return [run.name for run in root.runs]

def train_cross_validation(split_id):
    runIDs = get_all_run_ids()
    # Hold out a single tomogram for validation; train on the rest
    valRunID = [runIDs[split_id]]
    trainRunIDs = runIDs[:split_id] + runIDs[split_id+1:]
    results_folder = f'split_{split_id}'
    # Same arguments as the basic training setup above
    train(
        config, target_info, 'denoised', 10.012, loss_function,
        model_config,
        trainRunIDs=trainRunIDs,
        validateRunIDs=valRunID,
        model_save_path=results_folder
    )
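Each split can then be trained with a simple driver loop (shown sequentially here; on a cluster you would typically submit one job per split):
# Train one model per split, each validated on a different held-out tomogram
for split_id in range(len(get_all_run_ids())):
    train_cross_validation(split_id)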
Loss Functions¶
Choosing the right loss function is crucial for effective particle segmentation, especially given the class imbalance typical in cryo-ET data where background voxels vastly outnumber particle voxels.
- TverskyLoss: Generalizes Dice loss with separate weights for false positives (alpha) and false negatives (beta). Excellent for imbalanced data - use a higher beta (e.g., 0.7) to penalize missed particles more than false detections.
- DiceLoss: Standard choice for segmentation, works well when classes are relatively balanced.
Octopi Custom Loss Functions:
- FocalTverskyLoss: Combines Tversky's class balancing with Focal's hard example mining. The gamma parameter adds focusing power to the Tversky formulation.
- WeightedFocalTverskyLoss: Weighted combination of separate Focal and Tversky losses, allowing fine-tuned control over both class imbalance and hard example emphasis through the weight_tversky and weight_focal parameters.
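For reference, these losses are built on the Tversky index TI = TP / (TP + alpha*FP + beta*FN), with the Focal Tversky loss raising its complement to a power, (1 - TI)^gamma. A hedged usage sketch follows; the argument names are assumed from this formulation, so check octopi.losses for the exact signature:
from octopi.losses import FocalTverskyLoss
# Penalize missed particles more (beta > alpha) and focus on hard examples (gamma > 1).
# NOTE: argument names are assumptions; verify against octopi.losses.
loss_function = FocalTverskyLoss(alpha=0.3, beta=0.7, gamma=1.5)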
Model Exploration¶
In cases where we'd like to automatically explore the model architecture landscape and determine which configuration is optimal for a given experiment, we can use the ModelSearchSubmit class. Here, a Bayesian optimizer explores various loss functions and their associated hyperparameters, as well as the architecture parameters defined in the model class. This automated approach significantly reduces the number of input parameters you need to provide, as the system intelligently searches through the hyperparameter space.
The automated search process uses Optuna's Bayesian optimization to efficiently explore combinations of:
- Loss function types and their hyperparameters (alpha, beta, gamma values)
- Model architecture parameters (channel sizes, dropout rates, number of residual units)
This is particularly valuable when working with new datasets or particle types where optimal configurations are unknown.
from octopi.pytorch.model_search_submitter import ModelSearchSubmit
config = 'config.json'
target_name = 'targets'
target_user_id = 'octopi'
target_session_id = '1'
tomo_algorithm = 'denoised'
voxel_size = 10
Nclass = 6 # number of objects + 1 for background
optimizer = ModelSearchSubmit(
    config, target_name, target_user_id, target_session_id,
    tomo_algorithm, voxel_size, Nclass, 'UNet'
)
optimizer.run_model_search()
💡 ModelSearchSubmit Class Reference
ModelSearchSubmit(copick_config, target_name, target_user_id, target_session_id, tomo_algorithm, voxel_size, Nclass, model_type, best_metric='avg_f1', num_epochs=1000, num_trials=100, data_split=0.8, random_seed=42, val_interval=10, tomo_batch_size=15, trainRunIDs=None, validateRunIDs=None, mlflow_experiment_name='explore')
Initialize the ModelSearch class for architecture search with Optuna.
Parameters:
- copick_config (str or dict): Path to the CoPick configuration file or a dictionary for multi-config training
- target_name (str): Name of the target for segmentation
- target_user_id (str): User ID for target tracking
- target_session_id (str): Session ID for target tracking
- tomo_algorithm (str): Tomogram algorithm to use
- voxel_size (float): Voxel size for tomograms
- Nclass (int): Number of prediction classes
- model_type (str): Type of model to use (e.g., 'UNet')
- best_metric (str): Metric to optimize (default: 'avg_f1')
- num_epochs (int): Number of epochs per trial (default: 1000)
- num_trials (int): Number of trials for hyperparameter optimization (default: 100)
- data_split (float): Data split ratio for train/validation (default: 0.8)
- random_seed (int): Seed for reproducibility (default: 42)
- val_interval (int): Validation interval during training (default: 10)
- tomo_batch_size (int): Batch size for tomogram loading (default: 15)
- trainRunIDs (List[str]): List of training run IDs (default: None - uses all available)
- validateRunIDs (List[str]): List of validation run IDs (default: None - uses all available)
- mlflow_experiment_name (str): MLflow experiment name for tracking (default: 'explore')
Methods:
- run_model_search(): Executes the Bayesian optimization search across the defined parameter space
Outputs:
- MLflow experiment logs with trial results and metrics
- Best model configuration and weights
- Hyperparameter optimization history and visualizations
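Because every trial is logged to MLflow, the results can be pulled back into Python for comparison. A minimal sketch using the standard MLflow API (assuming the default local tracking store and the 'explore' experiment name from above; the metric column name is an assumption, so inspect runs.columns for the exact names):
import mlflow
# Fetch all logged trials for the experiment as a pandas DataFrame
runs = mlflow.search_runs(experiment_names=['explore'])
# Rank trials by the optimized metric ('metrics.avg_f1' is assumed)
print(runs.sort_values('metrics.avg_f1', ascending=False).head())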
Next Steps¶
Once you've successfully trained your model, you're ready to move to the inference stage. The trained model outputs (.pth weights and .yaml configuration) from this training process will be used directly in the inference pipeline.
Ready to apply your trained model? Head to the Inference Guide to learn how to:
- Run segmentation on new tomograms using your trained weights
- Perform particle localization from segmentation masks
- Evaluate your model's performance against ground truth data
Want to experiment with custom architectures? Check out the Adding New Models template to:
- Implement custom neural network architectures
- Integrate new model types into the octopi framework
- Extend the automated model search capabilities
Pro tip: If you're unsure about optimal hyperparameters, consider running the automated model exploration first before manual training - it can save significant time and often discovers configurations you might not have considered!