Asynchronous Multi-GPU Inference

The GPUPool class enables parallel inference across multiple GPUs in SABER. It distributes deep learning models across the available devices, allowing large datasets to be processed asynchronously and in parallel.

Overview

GPUPool manages the distribution of tasks across available GPUs and handles the initialization of models on each device. It provides a streamlined interface for:

  • Initializing models on multiple GPUs
  • Distributing tasks across available GPU resources
  • Managing asynchronous execution with proper error handling
  • Tracking progress of batch processing operations

Usage

As a high-level overview, here is a complete example of applying SABER's parallel inference to 2D data with GPUPool.

from saber.entry_points import parallelization
import glob

# Get all micrograph files
files = glob.glob("path/to/micrographs/*.mrc")

# Create the processing pool (model_weights, model_config, and
# target_class are assumed to be defined elsewhere)
pool = parallelization.GPUPool(
    init_fn=initialize_model,
    init_args=(model_weights, model_config, target_class, "large"),
    verbose=True
)

# Prepare tasks: one (input_file, output_zarr, scale_factor) tuple per micrograph
tasks = [(fname, "output.zarr", 2) for fname in files]

# Execute batch processing
try:
    pool.execute(
        process_task,
        tasks, 
        task_ids=files,
        progress_desc="Processing micrographs"
    )
finally:
    pool.shutdown()
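Wrapping execute in try/finally guarantees that pool.shutdown() runs even when a task raises, so the worker processes (and the GPU memory they hold) are always released.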

The general workflow involves three key steps:

1. Define a Task List:

tasks = [(input_file1, output_path1, params1), 
         (input_file2, output_path2, params2),
         ...]
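Each task tuple is unpacked positionally into the processing function, so the tuple order must match the function's leading parameters. As a minimal sketch (the file path is hypothetical), a single task for the micrograph example looks like:

# Maps onto process_task(input, output, scale_factor, gpu_id, models)
tasks = [("micrographs/mic_001.mrc", "output.zarr", 2)]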

2. Create a Model Initialization Function

Define a function that initializes your model on a specific GPU. The GPU ID must be the first parameter:

import torch

from saber.segmenters.micro import cryoMicroSegmenter
from saber.classifier.models import common

def initialize_model(
    gpu_id: int,
    model_weights: str, model_config: str,
    target_class: int, sam2_cfg: str):
    """Load micrograph segmentation models once per GPU"""

    # Bind this worker process to its assigned GPU
    torch.cuda.set_device(gpu_id)

    # Load models
    predictor = common.get_predictor(model_weights, model_config, gpu_id)
    segmenter = cryoMicroSegmenter(
        sam2_cfg=sam2_cfg,
        deviceID=gpu_id,
        classifier=predictor,
        target_class=target_class
    )

    return {
        'predictor': predictor,
        'segmenter': segmenter
    }
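The dictionary returned here is handed to the processing function as its models argument, so the models are loaded once per GPU and then reused for every task scheduled on that device.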

3. Create a Processing Function

Define a function that processes each task using the initialized models. The GPU ID and the models dictionary must be the last two parameters:

import os

import numpy as np
import torch

# NOTE: zarr_writer, io, FourierRescale2D, and mask_filters are SABER
# utilities used below; their import paths depend on your installation.

def process_task(
    input: str, output: str,
    scale_factor: float, gpu_id, models):

    # Get the global Zarr writer for this output store
    zwriter = zarr_writer.get_zarr_writer(output)

    # Use the pre-loaded segmenter
    segmenter = models['segmenter']

    # Ensure we're on the correct GPU
    torch.cuda.set_device(gpu_id)

    # Read the micrograph
    image, pixel_size = io.read_micrograph(input)
    image = image.astype(np.float32)

    # Downsample the input image
    image = FourierRescale2D.run(image, scale_factor)

    # Produce initial segmentations with SAM2
    segmenter.segment(image, display_image=False)
    (image0, masks_list) = (segmenter.image0, segmenter.masks)

    # Convert masks to a NumPy array
    masks = mask_filters.masks_to_array(masks_list)

    # Write the run to Zarr, named after the input file
    run_name = os.path.splitext(os.path.basename(input))[0]
    zwriter.write(run_name=run_name, image=image0, masks=masks.astype(np.uint8))
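Every task appends to the same Zarr store: get_zarr_writer returns a shared, process-wide writer for the given output path, and each micrograph is saved under its own run_name.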

Key Parameters

Initialization

GPUPool(
    init_fn,                # Function to initialize model on each GPU
    init_args=None,         # Arguments passed to init_fn
    init_kwargs=None,       # Keyword arguments passed to init_fn
    gpu_ids=None,           # Specific GPU IDs to use (defaults to all available)
    num_workers=None,       # Number of worker processes (defaults to number of GPUs)
    verbose=False           # Enable/disable verbose output
)
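For example, to restrict processing to a subset of devices (a sketch; the init_args values are the placeholders from the example above):

pool = parallelization.GPUPool(
    init_fn=initialize_model,
    init_args=(model_weights, model_config, target_class, "large"),
    gpu_ids=[0, 2],      # run only on GPUs 0 and 2
    num_workers=2,       # one worker per selected GPU
    verbose=True
)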

Execution

execute(
    fn,                    # Function to execute on each task
    tasks,                 # List of task parameters
    task_ids=None,         # Optional identifiers for each task
    progress_desc=None,    # Description for progress bar
    **kwargs               # Additional keyword arguments passed to fn
)
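For example, task_ids can carry human-readable labels so progress and error messages refer to files rather than task indices (a sketch reusing the names from the example above):

pool.execute(
    process_task,
    tasks,
    task_ids=[os.path.basename(f) for f in files],  # label each task by filename
    progress_desc="Segmenting micrographs"
)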