Training a PyTorch Model
This tutorial demonstrates training a simple logistic regression model on Census data using PyTorch and TileDB-SOMA-ML, which provides a PyTorch loader for SOMA datasets such as Census. Note: an earlier version of this notebook used a prototype loader in the ``cellxgene_census.experimental`` API, which has since been replaced by TileDB-SOMA-ML.
We assume basic familiarity with PyTorch and the Census API. See the Querying and fetching the single-cell data and cell/gene metadata notebook tutorial for a quick primer on Census API usage.
Contents
Create an ExperimentDataset
Split the dataset
Create the DataLoader
Define the model
Train the model
Make predictions with the model
Open the Census
First, obtain a handle to the Census data, in the usual manner:
[1]:
import cellxgene_census
census = cellxgene_census.open_soma()
The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.
Create an ExperimentDataset
To set up a PyTorch dataset from a Census slice, open an `ExperimentAxisQuery
<https://tiledbsoma.readthedocs.io/en/stable/python-tiledbsoma-experimentaxisquery.html>`__ and create a TileDB-SOMA-ML `ExperimentDataset
<https://single-cell-data.github.io/TileDB-SOMA-ML/#tiledbsoma_ml.ExperimentDataset>`__ for it. We will also prepare a scikit-learn `LabelEncoder
<https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html>`__ to help with mapping between
Census cell types (categorical metadata) and PyTorch tensors.
[2]:
import tiledbsoma as soma
from sklearn.preprocessing import LabelEncoder
from tiledbsoma_ml import ExperimentDataset
experiment = census["census_data"]["homo_sapiens"]
with experiment.axis_query(
    measurement_name="RNA",
    obs_query=soma.AxisQuery(value_filter="tissue_general == 'tongue' and is_primary_data == True"),
) as query:
    experiment_dataset = ExperimentDataset(
        query,
        layer_name="raw",
        obs_column_names=["cell_type"],
        batch_size=128,
        shuffle=True,
        seed=111,
    )

    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df["cell_type"].unique())
`ExperimentDataset <https://single-cell-data.github.io/TileDB-SOMA-ML/#tiledbsoma_ml.ExperimentDataset>`__ class explained

This class provides an implementation of PyTorch’s IterableDataset interface. It encapsulates streaming the result set of a SOMA `ExperimentAxisQuery <https://tiledbsoma.readthedocs.io/en/stable/python-tiledbsoma-experimentaxisquery.html>`__ in a series of “batches,” each consisting of a NumPy ndarray (a batch of X results) and a Pandas DataFrame (the corresponding batch of obs results). Most importantly, it avoids loading large result sets into memory all at once.

The constructor requires the `ExperimentAxisQuery <https://tiledbsoma.readthedocs.io/en/stable/python-tiledbsoma-experimentaxisquery.html>`__ defining the desired slice of Census data, and the name of the X layer to access. obs_column_names sets the columns to include in the DataFrame of each batch, to be used as data labels or model inputs. The batch_size parameter specifies the target number of rows (cells) in each batch. The shuffle flag randomizes the ordering of the training data for each training epoch (default: True). Use it instead of `DataLoader <https://pytorch.org/docs/stable/data.html>`__’s shuffle flag, as the ExperimentDataset implementation is more efficient.
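For a concrete look at what iteration yields, here is a minimal sketch (using the experiment_dataset created above) that pulls a single batch and inspects it:

# Pull one batch directly from the ExperimentDataset; each iteration yields
# a (NumPy ndarray, pandas DataFrame) pair of X data and obs metadata.
X_batch, obs_batch = next(iter(experiment_dataset))
print(X_batch.shape)  # roughly (batch_size, n_genes), e.g. (128, 61888)
print(obs_batch["cell_type"].value_counts().head())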
You can inspect the shape of the full dataset without causing it to be loaded. The shape property gives the number of batches in the first dimension and the number of genes (vars) in the second:
[3]:
experiment_dataset.shape
[3]:
(303, 61888)
[4]:
experiment_dataset.query_ids.obs_joinids.shape # total result count
[4]:
(38754,)
Split the dataset
You may split the query results into the typical training and test sets using the `ExperimentDataset.random_split()
<https://single-cell-data.github.io/TileDB-SOMA-ML/#tiledbsoma_ml.ExperimentDataset.random_split>`__ method:
[5]:
train_dataset, test_dataset = experiment_dataset.random_split(
    0.8,
    0.2,
    seed=111,
)
train_dataset.shape, test_dataset.shape
[5]:
((243, 61888), (61, 61888))
[6]:
train_dataset.query_ids.obs_joinids.shape, test_dataset.query_ids.obs_joinids.shape
[6]:
((31003,), (7751,))
Create the DataLoader
Now you can prepare a PyTorch `DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`__ on the training data. The DataLoader can also be configured with multiple worker processes to stream the result batches as quickly as possible (see the sketch below).
[7]:
from tiledbsoma_ml import experiment_dataloader
train_dataloader = experiment_dataloader(train_dataset)
(Instantiating a PyTorch DataLoader object directly is not recommended, as `experiment_dataloader() <https://single-cell-data.github.io/TileDB-SOMA-ML/#tiledbsoma_ml.experiment_dataloader>`__ enforces correct and performant usage.)
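For example, a multi-worker loader might be created as follows; this is a sketch, assuming experiment_dataloader() forwards standard DataLoader keyword arguments such as num_workers:

# Sketch: assuming extra keyword arguments are passed through to
# torch.utils.data.DataLoader; num_workers=2 is illustrative.
train_dataloader_mp = experiment_dataloader(train_dataset, num_workers=2)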
Define the model
With the training data retrieval code now in place, we can move on to defining a simple logistic regression model using PyTorch’s torch.nn.Linear class:
[8]:
import torch
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs
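As a quick sanity check of the model’s input/output shapes (the dimensions below are arbitrary, not Census-derived):

# Smoke test: a batch of 8 "cells" x 10 "genes" in, 3 classes out.
toy_model = LogisticRegression(input_dim=10, output_dim=3)
toy_model(torch.randn(8, 10)).shape  # -> torch.Size([8, 3])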
Next, we define a function to train the model for a single epoch:
[9]:
def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for X_batch, obs_batch in train_dataloader:
        optimizer.zero_grad()

        # Convert the X_batch numpy ndarray into a PyTorch tensor
        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Perform prediction
        outputs = model(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute the loss using encoded labels and perform back propagation
        label_batch = torch.from_numpy(cell_type_encoder.transform(obs_batch["cell_type"])).to(device)
        train_correct += (predictions == label_batch).sum().item()
        train_total += len(predictions)

        loss = loss_fn(outputs, label_batch.long())
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss /= train_total
    train_accuracy = train_correct / train_total
    return train_loss, train_accuracy
The X_batch ndarray (later converted to a tensor) holds the slice of the RNA X matrix for the current batch of results. Similarly, the obs_batch dataframe holds the corresponding slice of obs with the cell_type metadata. We use cell_type_encoder to encode these categorical cell types into integer-valued tensors for comparison with the model predictions.
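As an illustration of the encoder round trip (the specific integer codes depend on the sorted order of the fitted classes):

# Encode categorical labels to integer codes, then decode them back.
codes = cell_type_encoder.transform(["basal cell", "fibroblast"])
cell_type_encoder.inverse_transform(codes)  # -> array(['basal cell', 'fibroblast'], dtype=object)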
Train the model
Finally, we are ready to train the model. Here we instantiate the model, a loss function, and an optimization method, then iterate through the desired number of training epochs. Note how the train_dataloader is passed into train_epoch, where for each epoch it provides a new iterator through the training dataset.
[10]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# The size of the input dimension is the number of genes
input_dim = experiment_dataset.shape[1]
# The size of the output dimension is the number of distinct cell_type values
output_dim = len(cell_type_encoder.classes_)
model = LogisticRegression(input_dim, output_dim).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(25):
    train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
    print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}")
Epoch 1: Train Loss: 0.0196637 Accuracy 0.4806
Epoch 2: Train Loss: 0.0186214 Accuracy 0.7383
Epoch 3: Train Loss: 0.0183390 Accuracy 0.8257
Epoch 4: Train Loss: 0.0181079 Accuracy 0.8503
Epoch 5: Train Loss: 0.0180089 Accuracy 0.8603
Epoch 6: Train Loss: 0.0179505 Accuracy 0.8641
Epoch 7: Train Loss: 0.0179011 Accuracy 0.8678
Epoch 8: Train Loss: 0.0178689 Accuracy 0.8694
Epoch 9: Train Loss: 0.0178441 Accuracy 0.8718
Epoch 10: Train Loss: 0.0178233 Accuracy 0.8728
Epoch 11: Train Loss: 0.0178058 Accuracy 0.8741
Epoch 12: Train Loss: 0.0177893 Accuracy 0.8755
Epoch 13: Train Loss: 0.0177776 Accuracy 0.8737
Epoch 14: Train Loss: 0.0177634 Accuracy 0.8771
Epoch 15: Train Loss: 0.0177532 Accuracy 0.8783
Epoch 16: Train Loss: 0.0177437 Accuracy 0.8793
Epoch 17: Train Loss: 0.0177350 Accuracy 0.8802
Epoch 18: Train Loss: 0.0177275 Accuracy 0.8811
Epoch 19: Train Loss: 0.0177217 Accuracy 0.8816
Epoch 20: Train Loss: 0.0177154 Accuracy 0.8823
Epoch 21: Train Loss: 0.0177091 Accuracy 0.8838
Epoch 22: Train Loss: 0.0177030 Accuracy 0.8840
Epoch 23: Train Loss: 0.0176972 Accuracy 0.8847
Epoch 24: Train Loss: 0.0176921 Accuracy 0.8851
Epoch 25: Train Loss: 0.0176756 Accuracy 0.8856
Make predictions with the model
To make predictions with the model, we first create a new DataLoader using the test_dataset, which provides the “test” split of the original dataset. For this example, we will only make predictions on a single batch of data from the test split.
[11]:
test_dataloader = experiment_dataloader(test_dataset)
X_batch, obs_batch = next(iter(test_dataloader))
X_batch = torch.from_numpy(X_batch)
true_cell_types = torch.from_numpy(cell_type_encoder.transform(obs_batch["cell_type"]))
Next, we invoke the model on the X_batch
input data and extract the predictions:
[12]:
model.eval()
model.to(device)
outputs = model(X_batch.to(device))
probabilities = torch.nn.functional.softmax(outputs, 1)
predictions = torch.argmax(probabilities, axis=1)
predictions
[12]:
tensor([ 5, 8, 6, 20, 20, 20, 20, 5, 5, 20, 5, 8, 19, 9, 22, 5, 6, 5,
20, 5, 5, 9, 5, 20, 20, 1, 8, 5, 1, 0, 8, 5, 20, 5, 20, 5,
8, 20, 20, 5, 20, 1, 20, 20, 5, 20, 10, 9, 5, 5, 20, 5, 6, 9,
2, 5, 20, 8, 20, 5, 5, 5, 5, 5, 20, 20, 5, 1, 5, 5, 9, 20,
5, 20, 20, 5, 5, 5, 6, 12, 5, 5, 20, 1, 5, 5, 5, 20, 20, 20,
20, 5, 20, 22, 20, 5, 5, 5, 5, 5, 20, 2, 1, 19, 5, 5, 5, 5,
20, 5, 5, 20, 0, 22, 5, 1, 20, 2, 20, 20, 19, 5, 5, 5, 5, 5,
16, 19], device='cuda:0')
The predictions are returned as encoded cell_type values. To recover the original cell type labels as strings, we decode them using the same LabelEncoder used for training.

At inference time, if the model inputs are not obtained via an ExperimentDataset, the encoder can be pickled at training time and saved along with the model, then unpickled at inference time and used as shown below.
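A minimal sketch of that save/restore workflow (the file names are arbitrary):

import pickle

# At training time: save the model weights and pickle the fitted encoder.
torch.save(model.state_dict(), "model.pt")
with open("cell_type_encoder.pkl", "wb") as f:
    pickle.dump(cell_type_encoder, f)

# At inference time: restore the encoder and decode predictions as below.
with open("cell_type_encoder.pkl", "rb") as f:
    cell_type_encoder_restored = pickle.load(f)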
[13]:
predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())
predicted_cell_types
[13]:
array(['basal cell', 'fibroblast', 'endothelial cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell', 'basal cell', 'basal cell',
'stratified squamous epithelial cell', 'basal cell', 'fibroblast',
'salivary gland cell', 'macrophage', 'tongue muscle cell',
'basal cell', 'endothelial cell', 'basal cell',
'stratified squamous epithelial cell', 'basal cell', 'basal cell',
'macrophage', 'basal cell', 'stratified squamous epithelial cell',
'stratified squamous epithelial cell',
'CD4-positive, alpha-beta T cell', 'fibroblast', 'basal cell',
'CD4-positive, alpha-beta T cell', 'B cell', 'fibroblast',
'basal cell', 'stratified squamous epithelial cell', 'basal cell',
'stratified squamous epithelial cell', 'basal cell', 'fibroblast',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell', 'basal cell',
'stratified squamous epithelial cell',
'CD4-positive, alpha-beta T cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell', 'basal cell',
'stratified squamous epithelial cell', 'mast cell', 'macrophage',
'basal cell', 'basal cell', 'stratified squamous epithelial cell',
'basal cell', 'endothelial cell', 'macrophage',
'CD8-positive, alpha-beta T cell', 'basal cell',
'stratified squamous epithelial cell', 'fibroblast',
'stratified squamous epithelial cell', 'basal cell', 'basal cell',
'basal cell', 'basal cell', 'basal cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell', 'basal cell',
'CD4-positive, alpha-beta T cell', 'basal cell', 'basal cell',
'macrophage', 'stratified squamous epithelial cell', 'basal cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell', 'basal cell', 'basal cell',
'basal cell', 'endothelial cell', 'mural cell', 'basal cell',
'basal cell', 'stratified squamous epithelial cell',
'CD4-positive, alpha-beta T cell', 'basal cell', 'basal cell',
'basal cell', 'stratified squamous epithelial cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell', 'basal cell',
'stratified squamous epithelial cell', 'tongue muscle cell',
'stratified squamous epithelial cell', 'basal cell', 'basal cell',
'basal cell', 'basal cell', 'basal cell',
'stratified squamous epithelial cell',
'CD8-positive, alpha-beta T cell',
'CD4-positive, alpha-beta T cell', 'salivary gland cell',
'basal cell', 'basal cell', 'basal cell', 'basal cell',
'stratified squamous epithelial cell', 'basal cell', 'basal cell',
'stratified squamous epithelial cell', 'B cell',
'tongue muscle cell', 'basal cell',
'CD4-positive, alpha-beta T cell',
'stratified squamous epithelial cell',
'CD8-positive, alpha-beta T cell',
'stratified squamous epithelial cell',
'stratified squamous epithelial cell', 'salivary gland cell',
'basal cell', 'basal cell', 'basal cell', 'basal cell',
'basal cell', 'neutrophil', 'salivary gland cell'], dtype=object)
Finally, we create a Pandas DataFrame to examine the predictions:
[14]:
import pandas as pd
batch_cmp_df = pd.DataFrame(
    {
        "true cell type": cell_type_encoder.inverse_transform(true_cell_types.ravel().numpy()),
        "predicted cell type": predicted_cell_types,
    }
)
batch_cmp_df
[14]:
|     | true cell type                      | predicted cell type                 |
|-----|-------------------------------------|-------------------------------------|
| 0   | stratified squamous epithelial cell | basal cell                          |
| 1   | fibroblast                          | fibroblast                          |
| 2   | basal cell                          | endothelial cell                    |
| 3   | stratified squamous epithelial cell | stratified squamous epithelial cell |
| 4   | stratified squamous epithelial cell | stratified squamous epithelial cell |
| ... | ...                                 | ...                                 |
| 123 | basal cell                          | basal cell                          |
| 124 | basal cell                          | basal cell                          |
| 125 | basal cell                          | basal cell                          |
| 126 | basal cell                          | neutrophil                          |
| 127 | salivary gland cell                 | salivary gland cell                 |

128 rows × 2 columns
[15]:
pd.crosstab(
    batch_cmp_df["true cell type"],
    batch_cmp_df["predicted cell type"],
).replace(0, "")
[15]:
| true cell type \ predicted cell type | B cell | CD4-positive, alpha-beta T cell | CD8-positive, alpha-beta T cell | basal cell | endothelial cell | fibroblast | macrophage | mast cell | mural cell | neutrophil | salivary gland cell | stratified squamous epithelial cell | tongue muscle cell |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| B cell | 2 | | | | | | | | | | | | |
| CD4-positive, alpha-beta T cell | | 6 | | | | | | | | | | | |
| CD8-positive, alpha-beta T cell | | | 3 | | | | | | | | | | |
| Schwann cell | | | | | | | | | | | | | 1 |
| basal cell | | | | 46 | 1 | | | | | 1 | | 3 | |
| endothelial cell | | | | | 3 | | | | | | | | |
| fibroblast | | | | | | 6 | | | | | | | |
| macrophage | | | | | | | 2 | | | | | | |
| mast cell | | | | | | | | 1 | | | | | |
| monocyte | | | | | | | 3 | | | | | | |
| mural cell | | | | | | | | | 1 | | | | |
| myoepithelial cell | | | | 1 | | | | | | | | | |
| regulatory T cell | | 1 | | | | | | | | | | | |
| salivary gland cell | | | | 2 | | | | | | | 4 | | |
| stratified squamous epithelial cell | | | | 5 | | | | | | | | 34 | |
| tongue muscle cell | | | | | | | | | | | | | 2 |