How to Train a Model on MNIST with FiftyOne and Torch#
This recipe demonstrates how to train a PyTorch model on the MNIST dataset using FiftyOneTorchDataset. This is useful when you want to build and evaluate models in Torch while managing your data pipeline directly from FiftyOne. Specifically, it covers:
Loading the MNIST dataset from the Dataset Zoo
Creating train/validation/test splits with FiftyOne’s tagging and random splitting utilities
Building a subset of the dataset for faster experimentation
Running a simple training loop via an external script (mnist_training.py)
Saving model weights for later evaluation or reuse
API references: FiftyOneTorchDataset · GetItem
Setup#
If you haven’t already, install FiftyOne:
[ ]:
!pip install fiftyone
This recipe also uses PyTorch to work with tensors and inspect sample data. If torch and torchvision are not already installed, install them:
[ ]:
!pip install torch torchvision
Import Libraries#
[1]:
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.random as four
[2]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import torchvision.transforms.v2 as transforms
from torchvision import tv_tensors
import matplotlib.pyplot as plt
import matplotlib.patches as plt_patches
from PIL import Image
import urllib.request
To run this recipe, you’ll need the mnist_training.py script (source on GitHub), which contains a complete PyTorch training loop built on FiftyOneTorchDataset. The following cell downloads it into your working directory so it can be imported directly.
[ ]:
url = "https://cdn.voxel51.com/tutorials_torch_dataset_examples/notebook_simple_training_example/mnist_training.py"
urllib.request.urlretrieve(url, "mnist_training.py")
[ ]:
# utils.py is shared across the torch-dataset-examples notebooks
url = "https://cdn.voxel51.com/tutorials_torch_dataset_examples/notebook_the_cache_field_names_argument/utils.py"
urllib.request.urlretrieve(url, "utils.py")
[3]:
import mnist_training
[4]:
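# Start DataLoader workers via "forkserver" and preload heavy imports in the server process.
# Run this cell once per session: setting the start method a second time raises a RuntimeError.
torch.multiprocessing.set_start_method("forkserver")
torch.multiprocessing.set_forkserver_preload(["torch", "fiftyone"])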
Training MNIST with FiftyOneTorchDataset#
In this section we load MNIST from the Dataset Zoo, define train/validation/test splits, and then call mnist_training.main() directly. Under the hood, the script uses a FiftyOneTorchDataset and a GetItem to build each split’s dataloader from FiftyOne views.
[ ]:
mnist = foz.load_zoo_dataset("mnist")
mnist.persistent = True
Before defining splits, you can optionally explore the dataset in the FiftyOne App:
[ ]:
fo.launch_app(mnist, auto=False)
Now let’s define a validation split from the non-test samples. FiftyOne’s random_split makes this straightforward:
[ ]:
# remove existing 'train' or 'validation' tags if they exist
mnist.untag_samples(["train", "validation"])
# create a random split on all non-test samples
not_test = mnist.match_tags("test", bool=False)
four.random_split(not_test, {"train": 0.9, "validation": 0.1})
print(mnist.count_sample_tags())
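To make the effect of the split concrete, here is a plain-Python sketch of a fraction-based random split over a list of sample IDs. It only illustrates the idea and is not FiftyOne's implementation of random_split; the helper name `split_ids` and its `seed` argument are assumptions introduced for this example.

```python
import random


def split_ids(ids, fractions, seed=None):
    """Randomly partition ids into named groups sized by fractions.

    Hypothetical helper for illustration; fractions are normalized by their sum.
    """
    ids = list(ids)
    random.Random(seed).shuffle(ids)
    total = sum(fractions.values())
    names = list(fractions)
    splits, start, cumulative = {}, 0, 0.0
    for position, name in enumerate(names):
        cumulative += fractions[name]
        # the last group takes whatever is left so that no ID is dropped
        if position == len(names) - 1:
            end = len(ids)
        else:
            end = round(len(ids) * cumulative / total)
        splits[name] = ids[start:end]
        start = end
    return splits


# 60000 stands in for the number of non-test MNIST samples
splits = split_ids(range(60000), {"train": 0.9, "validation": 0.1}, seed=51)
print({name: len(members) for name, members in splits.items()})
```

Each ID lands in exactly one group, which is the same property the tags give you: a sample carries either the train tag or the validation tag, never both.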
[ ]:
# build a small subset for faster experimentation
samples = []
samples += mnist.match_tags("train").take(1000).values("id")
for tag in ["test", "validation"]:
samples += mnist.match_tags(tag).values("id")
subset = mnist.select(samples)
[ ]:
from pathlib import Path
device = "cuda" if torch.cuda.is_available() else "cpu"
path_to_save_weights = Path("./mnist_weights")
path_to_save_weights.mkdir(parents=True, exist_ok=True)
mnist_training.main(subset, 10, 10, device, str(path_to_save_weights))
Training is complete. Predictions are written back to the underlying MNIST dataset via sample IDs, so opening mnist (the full dataset) will show all test-split predictions for review:
fo.launch_app(mnist)
Understanding the Training Script#
The mnist_training.py script contains the full training loop. Here we walk through its key design decisions.
DataLoader Creation with FiftyOneTorchDataset and GetItem#
create_dataloaders() calls dataset.match_tags(split).to_torch(get_item) for each split tag, converting every FiftyOne view directly into a FiftyOneTorchDataset. The MnistGetItem class — a subclass of GetItem — declares which fields to load and converts each sample into model-ready tensors:
class MnistGetItem(GetItem):
    def __init__(self):
        # map the keys used in __call__ to the FiftyOne fields to load
        super().__init__(
            field_mapping={"id": "id", "filepath": "filepath", "label": "ground_truth.label"}
        )

    def __call__(self, sample):
        image = convert_and_normalize(Image.open(sample["filepath"]).convert("RGB"))
        # the leading character of the label string is used as the class index
        label = int(sample["label"][0])
        return {"image": image, "label": label, "id": sample["id"]}
The resulting dataset plugs directly into torch.utils.data.DataLoader. The only required addition is worker_init_fn=FiftyOneTorchDataset.worker_init, which lets FiftyOne open its own database connection inside each worker process.
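The `int(sample["label"][0])` conversion in MnistGetItem relies on the label strings starting with the digit they represent, for example "0 - zero". The sketch below shows that conversion and its inverse in plain Python. The class list and both helper names are written out here for illustration and are assumptions; the notebook's own utils.py may implement the mapping differently.

```python
# Assumed label strings, following the "<digit> - <name>" pattern
MNIST_CLASSES = [
    "0 - zero", "1 - one", "2 - two", "3 - three", "4 - four",
    "5 - five", "6 - six", "7 - seven", "8 - eight", "9 - nine",
]


def label_string_to_index(label):
    """Return the digit encoded in the first character of a label string."""
    return int(label[0])


def index_to_label_string(index):
    """Map a class index (0-9) back to its label string."""
    return MNIST_CLASSES[index]


# the two helpers are inverses of each other for every class
for index in range(10):
    assert label_string_to_index(index_to_label_string(index)) == index
```

Parsing only the first character is enough for ten single-digit classes; a dataset with more classes would need a lookup table instead.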
Versatility: Any View Becomes a Split#
Because FiftyOneTorchDataset is built from a FiftyOne view, any view produced by filtering, sorting, or tagging in FiftyOne can serve as a training or validation split, with no data duplication needed. The cells above illustrate this:
# 90/10 random split from all non-test samples
not_test = mnist.match_tags("test", bool=False)
four.random_split(not_test, {"train": 0.9, "validation": 0.1})
# Or scope training to a curated subset (samples is the list of sample IDs built earlier)
subset = mnist.select(samples)
You can pass any view — filtered by tag, label, quality metric, or brain-run result — straight into create_dataloaders() without changing the training script.
Writing Predictions Back to FiftyOne#
During evaluation, the script writes per-sample predictions back to the dataset:
fo_predictions = [
    fo.Classification(
        label=utils.mnist_index_to_label_string(np.argmax(sample_logits)),
        logits=sample_logits,
    )
    for sample_logits in prediction.detach().cpu().numpy()
]
samples.set_values("predictions", fo_predictions)
samples.save()
After training, you can open the FiftyOne App and immediately browse predictions, filter by confidence, and inspect misclassified samples — all within the same workflow.
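The prediction snippet in this section stores a label and raw logits. If you also want a probability-like score to filter on, one common approach is to apply a softmax to the logits and keep the largest value. The sketch below does this in plain Python. Both function names are hypothetical, and this step is not confirmed to be part of mnist_training.py; the resulting value could be passed as the `confidence` attribute when building each fo.Classification.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a sequence of floats."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def predict_with_confidence(logits):
    """Return (class_index, confidence) for one sample's logits."""
    probs = softmax(logits)
    index = max(range(len(probs)), key=probs.__getitem__)
    return index, probs[index]


index, confidence = predict_with_confidence([0.1, 2.5, -1.0, 0.3])
print(index, round(confidence, 3))
```

Subtracting the maximum logit before exponentiating leaves the result unchanged but avoids overflow for large logits.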
Evaluation with FiftyOne#
Once predictions are stored, FiftyOne’s built-in evaluation API runs directly on the test split:
results = dataset.match_tags("test").evaluate_classifications(
    "predictions",
    gt_field="ground_truth",
    eval_key="eval",
    classes=classes,
    k=3,
)
results.print_report(classes=classes)
Results are persisted under the eval key, making per-class metrics and top-k accuracy available for review in the App at any time.
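For reference, top-k accuracy counts a prediction as correct when the true class is among the k highest-scoring classes. The plain-Python sketch below illustrates that definition. It is not FiftyOne's evaluation code, and `top_k_accuracy` is a name chosen for this example.

```python
def top_k_accuracy(logits_per_sample, true_indices, k=3):
    """Fraction of samples whose true class is among the k largest logits."""
    if len(logits_per_sample) != len(true_indices):
        raise ValueError("logits and labels must have the same length")
    hits = 0
    for logits, truth in zip(logits_per_sample, true_indices):
        # class indices ordered from highest to lowest score
        ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        hits += truth in ranked[:k]
    return hits / len(true_indices)


example_logits = [
    [0.1, 0.7, 0.2, 0.0],  # true class 1 is ranked first
    [0.4, 0.3, 0.2, 0.1],  # true class 3 is ranked last
]
print(top_k_accuracy(example_logits, [1, 3], k=1))
print(top_k_accuracy(example_logits, [1, 3], k=4))
```

With k equal to the number of classes every sample counts as correct, so the metric is only informative for small k relative to the class count.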