{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "9c8899e7", "metadata": {}, "source": [ "# Training a PyTorch Model\n", "\n", "This tutorial shows how to train a logistic regression model in PyTorch using the Census API's `experimental.ml.ExperimentDataPipe` class. It is intended only to demonstrate the use of the `ExperimentDataPipe`, not as an example of how to train a biologically useful model.\n", "\n", "This tutorial assumes a basic familiarity with PyTorch and the Census API. See the [Querying and fetching the single-cell data and cell/gene metadata](https://chanzuckerberg.github.io/cellxgene-census/notebooks/api_demo/census_query_extract.html) notebook tutorial for a quick primer on Census API usage.\n", "\n", "**Contents**\n", "\n", "* [Open the Census](#Open-the-Census)\n", "* [Create an ExperimentDataPipe](#Create-an-ExperimentDataPipe)\n", "* [Split the dataset](#Split-the-dataset)\n", "* [Create the DataLoader](#Create-the-DataLoader)\n", "* [Define the model](#Define-the-model)\n", "* [Train the model](#Train-the-model)\n", "* [Make predictions with the model](#Make-predictions-with-the-model)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f874fb88", "metadata": {}, "source": [ "## Open the Census\n", "\n", "First, obtain a handle to the Census data, in the usual manner:" ] }, { "cell_type": "code", "execution_count": 25, "id": "c3dd549f", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:20:21.600206Z", "start_time": "2023-10-09T18:20:19.390343Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:33:00.392773Z", "iopub.status.busy": "2023-07-28T16:33:00.392516Z", "iopub.status.idle": "2023-07-28T16:33:02.881471Z", "shell.execute_reply": "2023-07-28T16:33:02.880857Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The \"stable\" release is currently 2023-07-25. 
Specify 'census_version=\"2023-07-25\"' in future calls to open_soma() to ensure data consistency.\n" ] } ], "source": [ "import cellxgene_census\n", "\n", "census = cellxgene_census.open_soma()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "580b29f2", "metadata": {}, "source": [ "## Create an ExperimentDataPipe\n", "\n", "To train a model in PyTorch using this `census` data object, first instantiate an `ExperimentDataPipe` as follows:" ] }, { "cell_type": "code", "execution_count": 26, "id": "54896e6f", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:20:22.278894Z", "start_time": "2023-10-09T18:20:21.830683Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:33:02.884588Z", "iopub.status.busy": "2023-07-28T16:33:02.884133Z", "iopub.status.idle": "2023-07-28T16:33:03.356736Z", "shell.execute_reply": "2023-07-28T16:33:03.356115Z" } }, "outputs": [], "source": [ "import cellxgene_census.experimental.ml as census_ml\n", "import tiledbsoma as soma\n", "\n", "experiment = census[\"census_data\"][\"homo_sapiens\"]\n", "\n", "experiment_datapipe = census_ml.ExperimentDataPipe(\n", " experiment,\n", " measurement_name=\"RNA\",\n", " X_name=\"raw\",\n", " obs_query=soma.AxisQuery(value_filter=\"tissue_general == 'tongue' and is_primary_data == True\"),\n", " obs_column_names=[\"cell_type\"],\n", " batch_size=128,\n", " shuffle=True,\n", " soma_chunk_size=10_000,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6c7c17c3", "metadata": {}, "source": [] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### `ExperimentDataPipe` class explained\n", "\n", "This class provides an implementation of PyTorch's [DataPipe interface](https://pytorch.org/data/main/torchdata.datapipes.iter.html), which defines a common mechanism for wrapping and accessing training data from any underlying source. 
The `ExperimentDataPipe` class encapsulates the details of querying and retrieving Census data from a single SOMA `Experiment` and returning it to the caller as PyTorch Tensors. Most importantly, it retrieves the data lazily from the Census in batches, avoiding having to load the entire training dataset into memory at once. (Note: PyTorch also provides `Dataset` as a legacy interface for wrapping and accessing training data sources, but a `DataPipe` can be used interchangeably.)\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### `ExperimentDataPipe` parameters explained\n", "\n", "The constructor requires only a single parameter, `experiment`, which is a `soma.Experiment` containing the data of the organism to be used for training.\n", "\n", "To retrieve a subset of the Experiment's data, along either the `obs` or `var` axes, you may specify query filters via the `obs_query` and `var_query` parameters, which are both `soma.AxisQuery` objects.\n", "\n", "The values for the prediction label(s) that you intend to use for training are specified via the `obs_column_names` array.\n", "\n", "The `batch_size` allows you to specify the number of `obs` rows (cells) to be returned in each PyTorch tensor. You may omit this parameter if you want single rows (`batch_size=1`).\n", "\n", "The `shuffle` flag allows you to randomize the ordering of the training data for each training epoch. Note:\n", "* You should use this flag instead of the `DataLoader` `shuffle` flag, as `DataLoader` does not support shuffling when used with an `IterDataPipe` dataset.\n", "* PyTorch's TorchData library provides a [Shuffler](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.Shuffler.html) `DataPipe`, which is an alternate mechanism one can use to perform shuffling of an `IterDataPipe`. 
However, the `Shuffler` will not \"globally\" randomize the training data, as it only \"locally\" randomizes the ordering of the training data within fixed-size \"windows\". Due to the layout of Census data, a given \"window\" of Census data may be highly homogeneous in terms of its `obs` axis attribute values, and so this shuffling strategy may not provide sufficient randomization for certain types of models.\n", "\n", "The `soma_chunk_size` sets the number of rows of data that are retrieved from the Census and held in memory at a given time. This controls the maximum memory usage of the `ExperimentDataPipe`. Smaller values require less memory but also result in lower read performance. If you are running out of memory when training a model, try reducing this value. The default is set to retrieve ~1GB of data per chunk, which takes into account how many `var` (gene) columns are being requested. This parameter also affects the granularity of the \"global\" shuffling step when `shuffle=True` (see the `shuffle` parameter API docs for details)."
] }, { "attachments": {}, "cell_type": "markdown", "id": "84ac17d2", "metadata": {}, "source": [ "You can inspect the shape of the full dataset, without causing the full dataset to be loaded:" ] }, { "cell_type": "code", "execution_count": 27, "id": "70a2ddbe", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:20:32.049618Z", "start_time": "2023-10-09T18:20:22.821101Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:33:03.359927Z", "iopub.status.busy": "2023-07-28T16:33:03.359450Z", "iopub.status.idle": "2023-07-28T16:33:05.524015Z", "shell.execute_reply": "2023-07-28T16:33:05.523337Z" } }, "outputs": [ { "data": { "text/plain": [ "(15020, 60664)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "experiment_datapipe.shape" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5251109a", "metadata": {}, "source": [ "## Split the dataset\n", "\n", "You may split the overall dataset into the typical training, validation, and test sets by using the PyTorch [RandomSplitter](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.RandomSplitter.html#torchdata.datapipes.iter.RandomSplitter) `DataPipe`. 
Using PyTorch's functional form for chaining `DataPipe`s, this is done as follows:" ] }, { "cell_type": "code", "execution_count": 28, "id": "133f594f", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:20:32.052795Z", "start_time": "2023-10-09T18:20:32.051289Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:33:05.527106Z", "iopub.status.busy": "2023-07-28T16:33:05.526540Z", "iopub.status.idle": "2023-07-28T16:33:05.530907Z", "shell.execute_reply": "2023-07-28T16:33:05.530386Z" } }, "outputs": [], "source": [ "train_datapipe, test_datapipe = experiment_datapipe.random_split(weights={\"train\": 0.8, \"test\": 0.2}, seed=1)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a825bccf", "metadata": {}, "source": [ "## Create the DataLoader\n", "\n", "With the full set of DataPipe operations chained together, we can now instantiate a PyTorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) on the training data, using the `census_ml.experiment_dataloader` helper:" ] }, { "cell_type": "code", "execution_count": 29, "id": "39d30df2", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:20:32.056886Z", "start_time": "2023-10-09T18:20:32.052898Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:33:05.538514Z", "iopub.status.busy": "2023-07-28T16:33:05.538206Z", "iopub.status.idle": "2023-07-28T16:33:05.541008Z", "shell.execute_reply": "2023-07-28T16:33:05.540445Z" } }, "outputs": [], "source": [ "experiment_dataloader = census_ml.experiment_dataloader(train_datapipe)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8a3cbe3f", "metadata": {}, "source": [ "Alternatively, you can instantiate a `DataLoader` object directly via its constructor. However, many of its parameters are not usable with iterable-style DataPipes such as `ExperimentDataPipe`. In particular, the `shuffle`, `batch_size`, `sampler`, `batch_sampler`, and `collate_fn` parameters should not be specified. Using `experiment_dataloader` helps enforce correct usage."
 ] }, { "attachments": {}, "cell_type": "markdown", "id": "fb9e93b6", "metadata": {}, "source": [ "## Define the model\n", "\n", "With the training data retrieval code now in place, we can move on to defining a simple logistic regression model, using PyTorch's `torch.nn.Linear` class:" ] }, { "cell_type": "code", "execution_count": 30, "id": "6b792b4b", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:20:32.060262Z", "start_time": "2023-10-09T18:20:32.058875Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:33:05.543534Z", "iopub.status.busy": "2023-07-28T16:33:05.543229Z", "iopub.status.idle": "2023-07-28T16:33:05.546861Z", "shell.execute_reply": "2023-07-28T16:33:05.546267Z" } }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "class LogisticRegression(torch.nn.Module):\n", "    def __init__(self, input_dim, output_dim):\n", "        super(LogisticRegression, self).__init__()  # noqa: UP008\n", "        self.linear = torch.nn.Linear(input_dim, output_dim)\n", "\n", "    def forward(self, x):\n", "        outputs = torch.sigmoid(self.linear(x))\n", "        return outputs" ] }, { "attachments": {}, "cell_type": "markdown", "id": "0e1752ef", "metadata": {}, "source": [ "Next, we define a function to train the model for a single epoch:" ] }, { "cell_type": "code", "execution_count": 31, "id": "b744cd21", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:20:32.310829Z", "start_time": "2023-10-09T18:20:32.307661Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:33:05.549312Z", "iopub.status.busy": "2023-07-28T16:33:05.549001Z", "iopub.status.idle": "2023-07-28T16:33:05.562074Z", "shell.execute_reply": "2023-07-28T16:33:05.561501Z" } }, "outputs": [], "source": [ "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n", "    model.train()\n", "    train_loss = 0\n", "    train_correct = 0\n", "    train_total = 0\n", "\n", "    for batch in train_dataloader:\n", "        optimizer.zero_grad()\n", "        X_batch, y_batch = batch\n", "\n", "        X_batch = X_batch.float().to(device)\n", "\n", "        # Perform prediction\n", "        outputs = model(X_batch)\n", "\n", "        # Determine the predicted label\n", "        probabilities = torch.nn.functional.softmax(outputs, 1)\n", "        predictions = torch.argmax(probabilities, axis=1)\n", "\n", "        # Compute the loss and perform backpropagation\n", "\n", "        # Keep only the cell_type labels, which are in the second column\n", "        y_batch = y_batch[:, 1]\n", "        y_batch = y_batch.to(device)\n", "\n", "        train_correct += (predictions == y_batch).sum().item()\n", "        train_total += len(predictions)\n", "\n", "        loss = loss_fn(outputs, y_batch.long())\n", "        train_loss += loss.item()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "    train_loss /= train_total\n", "    train_accuracy = train_correct / train_total\n", "    return train_loss, train_accuracy" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a0a9fd7e", "metadata": {}, "source": [ "Note the line, `X_batch, y_batch = batch`. Since the `ExperimentDataPipe` was configured with `batch_size=128`, these variables will hold tensors of rank 2. The `X_batch` tensor will appear, for example, as:\n", "\n", "```\n", "tensor([[0., 0., 0., ..., 1., 0., 0.],\n", "        [0., 0., 2., ..., 0., 3., 0.],\n", "        [0., 0., 0., ..., 0., 0., 0.],\n", "        ...,\n", "        [0., 0., 0., ..., 0., 0., 0.],\n", "        [0., 1., 0., ..., 0., 0., 0.],\n", "        [0., 0., 0., ..., 0., 0., 8.]])\n", "```\n", "\n", "For `batch_size=1`, the tensors will be of rank 1. The `X_batch` tensor will appear, for example, as:\n", "\n", "```\n", "tensor([0., 0., 0., ..., 1., 0., 0.])\n", "```\n", "\n", "Secondly, note the line, `y_batch = y_batch[:, 1]`. This line extracts the user-specified `obs` `cell_type` training labels from the second column of the rank 2 `y_batch` tensor. 
For example, this would take a `y_batch` tensor that looks like:\n", "```\n", "tensor([[42496620, 1],\n", "        [42496621, 1],\n", "        [42496622, 3],\n", "        ...,\n", "        [42496633, 2],\n", "        [42496634, 1],\n", "        [42496635, 4]], dtype=torch.int32)\n", "```\n", "and return:\n", "```\n", "tensor([1, 1, 3, ..., 2, 1, 4])\n", "```\n", "Note that the cell type values are integer-encoded and can be decoded using `experiment_datapipe.obs_encoders` (more on this below).\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "79f8b731", "metadata": {}, "source": [ "## Train the model\n", "\n", "Finally, we are ready to train the model. Here we instantiate the model, a loss function, and an optimization method and then iterate through the desired number of training epochs. Note how the `experiment_dataloader` is passed into `train_epoch` (as its `train_dataloader` argument), where for each epoch it will provide a new iterator through the training dataset." ] }, { "cell_type": "code", "execution_count": 32, "id": "733ec2fb", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:29:31.028253Z", "start_time": "2023-10-09T18:20:32.311816Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:33:05.564772Z", "iopub.status.busy": "2023-07-28T16:33:05.564454Z", "iopub.status.idle": "2023-07-28T16:34:04.801559Z", "shell.execute_reply": "2023-07-28T16:34:04.800846Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1: Train Loss: 0.0167253 Accuracy 0.4856\n", "Epoch 2: Train Loss: 0.0156710 Accuracy 0.4943\n", "Epoch 3: Train Loss: 0.0149408 Accuracy 0.4813\n", "Epoch 4: Train Loss: 0.0144469 Accuracy 0.5040\n", "Epoch 5: Train Loss: 0.0141749 Accuracy 0.5669\n", "Epoch 6: Train Loss: 0.0139776 Accuracy 0.6672\n", "Epoch 7: Train Loss: 0.0138565 Accuracy 0.7920\n", "Epoch 8: Train Loss: 0.0138094 Accuracy 0.8088\n", "Epoch 9: Train Loss: 0.0136689 Accuracy 0.8757\n", "Epoch 10: Train Loss: 0.0136101 Accuracy 0.8923\n" ] } ], "source": [ "device = torch.device(\"cuda\") if 
torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "# The size of the input dimension is the number of genes\n", "input_dim = experiment_datapipe.shape[1]\n", "\n", "# The size of the output dimension is the number of distinct cell_type values\n", "cell_type_encoder = experiment_datapipe.obs_encoders[\"cell_type\"]\n", "output_dim = len(cell_type_encoder.classes_)\n", "\n", "model = LogisticRegression(input_dim, output_dim).to(device)\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", "\n", "for epoch in range(10):\n", "    train_loss, train_accuracy = train_epoch(model, experiment_dataloader, loss_fn, optimizer, device)\n", "    print(f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5e0ffb48", "metadata": {}, "source": [ "## Make predictions with the model\n", "\n", "To make predictions with the model, we first create a new `DataLoader` using the `test_datapipe`, which provides the \"test\" split of the original dataset. For this example, we will only make predictions on a single batch of data from the test split."
] }, { "cell_type": "code", "execution_count": 33, "id": "d3e33edc", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:29:59.425527Z", "start_time": "2023-10-09T18:29:31.705548Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:34:04.804402Z", "iopub.status.busy": "2023-07-28T16:34:04.803987Z", "iopub.status.idle": "2023-07-28T16:34:09.331800Z", "shell.execute_reply": "2023-07-28T16:34:09.331168Z" } }, "outputs": [], "source": [ "experiment_dataloader = census_ml.experiment_dataloader(test_datapipe)\n", "X_batch, y_batch = next(iter(experiment_dataloader))" ] }, { "attachments": {}, "cell_type": "markdown", "id": "19fabd54", "metadata": {}, "source": [ "Next, we invoke the model on the `X_batch` input data and extract the predictions:" ] }, { "cell_type": "code", "execution_count": 34, "id": "00e12182", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:29:59.438079Z", "start_time": "2023-10-09T18:29:59.429107Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:34:09.334916Z", "iopub.status.busy": "2023-07-28T16:34:09.334470Z", "iopub.status.idle": "2023-07-28T16:34:09.340880Z", "shell.execute_reply": "2023-07-28T16:34:09.340293Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([ 1, 11, 1, 1, 5, 1, 1, 1, 1, 5, 1, 5, 1, 5, 5, 8, 1, 1,\n", " 7, 1, 5, 5, 8, 5, 5, 1, 1, 1, 1, 8, 9, 1, 1, 8, 1, 1,\n", " 1, 11, 5, 1, 8, 5, 5, 1, 5, 1, 5, 5, 1, 5, 9, 8, 1, 1,\n", " 1, 5, 5, 5, 1, 5, 1, 5, 1, 1, 5, 8, 1, 1, 1, 1, 7, 1,\n", " 5, 1, 1, 5, 5, 1, 1, 8, 5, 5, 8, 1, 1, 1, 5, 5, 5, 1,\n", " 5, 1, 5, 5, 1, 1, 5, 1, 5, 1, 1, 1, 5, 1, 1, 1, 9, 5,\n", " 1, 1, 7, 1, 1, 1, 1, 8, 1, 1, 5, 5, 1, 5, 1, 1, 1, 5,\n", " 8, 1])" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model.eval()\n", "\n", "model.to(device)\n", "outputs = model(X_batch.to(device))\n", "\n", "probabilities = torch.nn.functional.softmax(outputs, 1)\n", "predictions = torch.argmax(probabilities, axis=1)\n", "\n", "display(predictions)" ] }, { "attachments": {}, 
"cell_type": "markdown", "id": "7cb88a5f", "metadata": {}, "source": [ "The predictions are returned as the encoded values of `cell_type` label. To recover the original cell type labels as strings, we decode using the encoders from `experiment_datapipe.obs_encoders`.\n", "\n", "At inference time, if the model inputs are not obtained via an `ExperimentDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below." ] }, { "cell_type": "code", "execution_count": 35, "id": "1cfff865", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:29:59.441907Z", "start_time": "2023-10-09T18:29:59.439561Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:34:09.343375Z", "iopub.status.busy": "2023-07-28T16:34:09.343131Z", "iopub.status.idle": "2023-07-28T16:34:09.347842Z", "shell.execute_reply": "2023-07-28T16:34:09.347311Z" } }, "outputs": [ { "data": { "text/plain": [ "array(['basal cell', 'vein endothelial cell', 'basal cell', 'basal cell',\n", " 'epithelial cell', 'basal cell', 'basal cell', 'basal cell',\n", " 'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n", " 'basal cell', 'epithelial cell', 'epithelial cell', 'leukocyte',\n", " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", " 'epithelial cell', 'epithelial cell', 'leukocyte',\n", " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", " 'basal cell', 'basal cell', 'leukocyte', 'pericyte', 'basal cell',\n", " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", " 'basal cell', 'vein endothelial cell', 'epithelial cell',\n", " 'basal cell', 'leukocyte', 'epithelial cell', 'epithelial cell',\n", " 'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n", " 'epithelial cell', 'basal cell', 'epithelial cell', 'pericyte',\n", " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", " 'epithelial cell', 'epithelial cell', 'epithelial cell',\n", " 
'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n", " 'basal cell', 'basal cell', 'epithelial cell', 'leukocyte',\n", " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", " 'keratinocyte', 'basal cell', 'epithelial cell', 'basal cell',\n", " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n", " 'basal cell', 'leukocyte', 'epithelial cell', 'epithelial cell',\n", " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", " 'epithelial cell', 'epithelial cell', 'epithelial cell',\n", " 'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n", " 'epithelial cell', 'basal cell', 'basal cell', 'epithelial cell',\n", " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", " 'basal cell', 'pericyte', 'epithelial cell', 'basal cell',\n", " 'basal cell', 'keratinocyte', 'basal cell', 'basal cell',\n", " 'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n", " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n", " 'epithelial cell', 'basal cell', 'basal cell', 'basal cell',\n", " 'epithelial cell', 'leukocyte', 'basal cell'], dtype=object)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cell_type_encoder = experiment_datapipe.obs_encoders[\"cell_type\"]\n", "\n", "predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())\n", "\n", "display(predicted_cell_types)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "16010d09", "metadata": {}, "source": [ "Finally, we create a Pandas DataFrame to examine the predictions:" ] }, { "cell_type": "code", "execution_count": 36, "id": "f4ac8087", "metadata": { "ExecuteTime": { "end_time": "2023-10-09T18:29:59.471320Z", "start_time": "2023-10-09T18:29:59.443175Z" }, "execution": { "iopub.execute_input": "2023-07-28T16:34:09.350404Z", "iopub.status.busy": "2023-07-28T16:34:09.350006Z", "iopub.status.idle": 
"2023-07-28T16:34:09.701102Z", "shell.execute_reply": "2023-07-28T16:34:09.700533Z" } }, "outputs": [ { "data": { "text/html": [ "
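<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>actual cell type</th>\n", "      <th>predicted cell type</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>basal cell</td>\n", "      <td>basal cell</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>vein endothelial cell</td>\n", "      <td>vein endothelial cell</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>basal cell</td>\n", "      <td>basal cell</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>basal cell</td>\n", "      <td>basal cell</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>epithelial cell</td>\n", "      <td>epithelial cell</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>...</th>\n", "      <td>...</td>\n", "      <td>...</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>123</th>\n", "      <td>basal cell</td>\n", "      <td>basal cell</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>124</th>\n", "      <td>basal cell</td>\n", "      <td>basal cell</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>125</th>\n", "      <td>epithelial cell</td>\n", "      <td>epithelial cell</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>126</th>\n", "      <td>leukocyte</td>\n", "      <td>leukocyte</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>127</th>\n", "      <td>basal cell</td>\n", "      <td>basal cell</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "<p>128 rows × 2 columns</p>\n", "</div>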
" ], "text/plain": [ " actual cell type predicted cell type\n", "0 basal cell basal cell\n", "1 vein endothelial cell vein endothelial cell\n", "2 basal cell basal cell\n", "3 basal cell basal cell\n", "4 epithelial cell epithelial cell\n", ".. ... ...\n", "123 basal cell basal cell\n", "124 basal cell basal cell\n", "125 epithelial cell epithelial cell\n", "126 leukocyte leukocyte\n", "127 basal cell basal cell\n", "\n", "[128 rows x 2 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import pandas as pd\n", "\n", "display(\n", " pd.DataFrame(\n", " {\n", " \"actual cell type\": cell_type_encoder.inverse_transform(y_batch[:, 1].numpy()),\n", " \"predicted cell type\": predicted_cell_types,\n", " }\n", " )\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" } }, "nbformat": 4, "nbformat_minor": 5 }