Training and Evaluating FiftyOne Datasets with Detectron2
FiftyOne has all of the building blocks necessary to develop high-quality datasets to train your models, as well as advanced model evaluation capabilities. To make use of these, FiftyOne integrates easily with your existing model training and inference pipelines. In this walkthrough we’ll cover how you can use your FiftyOne datasets to train a model with Detectron2, Facebook AI Research’s library for detection and segmentation algorithms.
This walkthrough is based on the official Detectron2 tutorial, augmented to load data to and from FiftyOne.
Specifically, this walkthrough covers:
Loading a dataset from the FiftyOne Zoo, and splitting it into training/validation
Initializing a segmentation model from the detectron2 model zoo
Loading ground truth annotations from a FiftyOne dataset into a detectron2 model training pipeline and training the model
Loading predictions from a detectron2 model into a FiftyOne dataset
Evaluating model predictions in FiftyOne
So, what’s the takeaway?
By writing two simple functions, you can integrate FiftyOne into your Detectron2 model training and inference pipelines.
Setup
To get started, you need to install FiftyOne and detectron2:
[ ]:
!pip install fiftyone
[ ]:
import fiftyone as fo
import fiftyone.zoo as foz
[ ]:
!python -m pip install pyyaml==5.1
# Detectron2 has not released pre-built binaries for the latest pytorch (https://github.com/facebookresearch/detectron2/issues/4053)
# so we install from source instead. This takes a few minutes.
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
# Install pre-built detectron2 that matches pytorch version, if released:
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
#!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/{CUDA_VERSION}/torch{TORCH_VERSION}/index.html
[ ]:
import torch, detectron2
!nvcc --version
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0
torch: 1.12 ; cuda: cu113
detectron2: 0.6
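The two version strings printed above are what select a pre-built detectron2 wheel, when one exists. The sketch below shows how they could be combined into a wheel index URL. It assumes the layout wheels/<cuda tag>/torch<major.minor>/index.html used in detectron2's installation guide and assumes CPU-only builds live under "cpu"; the helper name detectron2_wheel_index is hypothetical, and as the comment in the install cell notes, wheels are not published for every torch release.

```python
def detectron2_wheel_index(torch_version: str) -> str:
    """Build a (hypothetical helper) detectron2 wheel index URL from a torch version string.

    Assumes the wheels/<cuda>/torch<major.minor>/index.html layout from
    detectron2's install guide; check that page for published combinations.
    """
    # "1.12.1+cu113" -> "1.12"
    torch_mm = ".".join(torch_version.split("+")[0].split(".")[:2])
    # "1.12.1+cu113" -> "cu113"; assume builds without a "+" suffix are CPU-only
    cuda_tag = torch_version.split("+")[-1] if "+" in torch_version else "cpu"
    return (
        "https://dl.fbaipublicfiles.com/detectron2/wheels/"
        f"{cuda_tag}/torch{torch_mm}/index.html"
    )


print(detectron2_wheel_index("1.12.1+cu113"))
```

In the notebook you would pass torch.__version__ instead of a literal string. Note that the original CUDA_VERSION expression returns the full version string when there is no "+" suffix, which is why the sketch handles that case separately.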
[ ]:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common libraries
import numpy as np
import os, cv2
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog
Train on a FiftyOne dataset
In this section, we show how to use a custom FiftyOne Dataset to train a detectron2 model. We’ll train a license plate segmentation model starting from a model pre-trained on the COCO dataset, available in detectron2’s model zoo.
Since the COCO dataset doesn’t have a “Vehicle registration plate” category, we will use segmentations of license plates from the Open Images v6 dataset in the FiftyOne Dataset Zoo to train the model to recognize this new category.
Prepare the dataset
For this example, we will only use samples from the official “validation” split of the dataset. Adding data from the official “train” split could improve model performance, but it would also take longer to download and train on, so we’ll stick to the “validation” split for this walkthrough.
[ ]:
dataset = foz.load_zoo_dataset(
    "open-images-v6",
    split="validation",
    classes=["Vehicle registration plate"],
    label_types=["segmentations"],
    label_field="segmentations",
)
Specifying classes when downloading a dataset from the zoo ensures that only samples containing at least one of the given classes are loaded. However, these samples may still contain labels of other classes, so we use FiftyOne’s filtering capabilities to keep only the “Vehicle registration plate” labels. We also remove the “validation” tag from these samples so that we can create our own train/val split from them.
[ ]:
from fiftyone import ViewField as F
# Remove other classes and existing tags
dataset.filter_labels("segmentations", F("label") == "Vehicle registration plate").save()
dataset.untag_samples("validation")
[ ]:
import fiftyone.utils.random as four
four.random_split(dataset, {"train": 0.8, "val": 0.2})
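The call above tags each sample with exactly one of "train" or "val", and those tags are what match_tags() selects when the splits are registered with detectron2. The plain-Python sketch below illustrates that idea on a list of IDs: shuffle, then cut the list by fraction, giving the remainder to the last split so every sample is assigned. It is an illustration under those assumptions, not FiftyOne's implementation, and random_split_ids is a hypothetical name.

```python
import random


def random_split_ids(ids, split_fracs, seed=None):
    """Assign every id to exactly one split name, roughly in proportion to split_fracs.

    Hypothetical helper illustrating a random train/val split; not FiftyOne's code.
    """
    rng = random.Random(seed)  # a fixed seed makes the split reproducible
    shuffled = list(ids)
    rng.shuffle(shuffled)

    total = sum(split_fracs.values())
    names = list(split_fracs)
    assignments = {}
    start = 0
    for i, name in enumerate(names):
        if i == len(names) - 1:
            end = len(shuffled)  # last split takes the remainder
        else:
            end = start + round(len(shuffled) * split_fracs[name] / total)
        for sample_id in shuffled[start:end]:
            assignments[sample_id] = name
        start = end
    return assignments


splits = random_split_ids([f"sample{i}" for i in range(10)], {"train": 0.8, "val": 0.2}, seed=51)
print(sorted(splits.values()).count("train"), "train /", sorted(splits.values()).count("val"), "val")
```

Because each ID ends up in exactly one split, the train and val views built from the tags never overlap.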
Next, we register the FiftyOne dataset with detectron2, following the detectron2 custom dataset tutorial. Because the dataset is stored in FiftyOne’s format rather than one that detectron2 reads natively, we write a function that parses it into detectron2’s standard format: a list of dicts, one per image, each holding that image’s annotations.
Note: In this example, we specifically convert each segmentation into a bounding box and a polygon. This function may require tweaks depending on the model being trained and the data it expects.
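The conversion relies on two pieces of coordinate arithmetic. FiftyOne stores bounding boxes as [top-left-x, top-left-y, width, height] and polyline points as lists of (x, y) vertices, all relative to image size in [0, 1], while detectron2’s BoxMode.XYWH_ABS boxes and polygon segmentations use absolute pixels, with each polygon flattened to [x1, y1, x2, y2, ...]. The sketch below isolates both steps as hypothetical helpers. Unlike the function that follows, it keeps every shape of a polyline rather than only the first, and drops shapes with fewer than three vertices, assuming you want behavior similar to the polygon filter detectron2 applies when loading COCO-format annotations.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def rel_box_to_abs_xywh(bbox: Sequence[float], width: int, height: int) -> List[int]:
    """Relative [tlx, tly, w, h] in [0, 1] -> pixel [x, y, w, h] (XYWH_ABS layout)."""
    tlx, tly, w, h = bbox
    return [int(tlx * width), int(tly * height), int(w * width), int(h * height)]


def rel_polygons_to_abs_flat(
    polygons: Sequence[Sequence[Point]], width: int, height: int
) -> List[List[float]]:
    """Relative vertex lists -> flat pixel lists [x1, y1, x2, y2, ...], one per shape.

    Hypothetical helper: keeps all shapes and skips those with fewer than 3 vertices.
    """
    flat = []
    for shape in polygons:
        coords = [c for x, y in shape for c in (x * width, y * height)]
        if len(coords) >= 6:  # at least 3 (x, y) vertices
            flat.append(coords)
    return flat


print(rel_box_to_abs_xywh([0.25, 0.5, 0.5, 0.25], 640, 480))
```

Inside get_fiftyone_dicts, the "segmentation" value could then be rel_polygons_to_abs_flat(fo_poly.points, width, height) instead of [poly], skipping any object for which that list comes back empty.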
[ ]:
from detectron2.structures import BoxMode

def get_fiftyone_dicts(samples):
    # Populate sample.metadata (including image width/height) where missing
    samples.compute_metadata()

    dataset_dicts = []
    for sample in samples.select_fields(["id", "filepath", "metadata", "segmentations"]):
        height = sample.metadata["height"]
        width = sample.metadata["width"]
        record = {}
        record["file_name"] = sample.filepath
        record["image_id"] = sample.id
        record["height"] = height
        record["width"] = width

        objs = []
        for det in sample.segmentations.detections:
            # FiftyOne boxes are [top-left-x, top-left-y, width, height] in relative [0, 1] coordinates
            tlx, tly, w, h = det.bounding_box
            bbox = [int(tlx*width), int(tly*height), int(w*width), int(h*height)]

            # Convert the instance mask to a polyline and scale its first shape to pixels
            fo_poly = det.to_polyline()
            poly = [(x*width, y*height) for x, y in fo_poly.points[0]]
            # Flatten [(x1, y1), (x2, y2), ...] into [x1, y1, x2, y2, ...]
            poly = [p for x in poly for p in x]

            obj = {
                "bbox": bbox,
                "bbox_mode": BoxMode.XYWH_ABS,
                "segmentation": [poly],
                "category_id": 0,  # single class: vehicle_registration_plate
            }
            objs.append(obj)

        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts

# Register one detectron2 dataset per split; `view=view` binds each lambda to its own view
for d in ["train", "val"]:
    view = dataset.match_tags(d)
    DatasetCatalog.register("fiftyone_" + d, lambda view=view: get_fiftyone_dicts(view))
    MetadataCatalog.get("fiftyone_" + d).set(thing_classes=["vehicle_registration_plate"])

metadata = MetadataCatalog.get("fiftyone_train")
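DatasetCatalog.register() stores a function that detectron2 calls later, when the dataset is actually requested, which is why the loop above writes lambda view=view rather than a bare lambda. The stand-alone sketch below uses plain dicts (hypothetical stand-ins for the catalog) to show the difference: a closure looks up the loop variable when it is called, after the loop has finished, while a default argument captures the value at definition time.

```python
# Plain dicts standing in for DatasetCatalog (illustration only)
registry_late = {}
registry_bound = {}

for split in ["train", "val"]:
    # Late binding: `split` is read when the lambda runs, i.e. after the loop ends
    registry_late["fiftyone_" + split] = lambda: "loading " + split
    # Default argument: the current value of `split` is captured now
    registry_bound["fiftyone_" + split] = lambda split=split: "loading " + split

print(registry_late["fiftyone_train"]())   # loading val
print(registry_bound["fiftyone_train"]())  # loading train
```

Without the default argument, both "fiftyone_train" and "fiftyone_val" would silently load the validation view.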
To verify that the dataset is in the correct format, let’s visualize the annotations of the training set in the FiftyOne App:
[ ]:
dataset_dicts = get_fiftyone_dicts(dataset.match_tags("train"))
ids = [dd["image_id"] for dd in dataset_dicts]
view = dataset.select(ids)
session = fo.launch_app(view)
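Visual inspection in the App can be complemented by a quick programmatic pass over the same dataset_dicts. The sketch below is a hypothetical helper, not part of detectron2 or FiftyOne: it assumes the record layout produced by get_fiftyone_dicts (XYWH_ABS boxes, flat polygon lists) and reports missing keys, boxes that are empty or extend past the image, and polygons with an odd number of coordinates or fewer than three vertices.

```python
from typing import Any, Dict, List

REQUIRED_KEYS = ("file_name", "image_id", "height", "width", "annotations")


def check_dataset_dicts(dataset_dicts: List[Dict[str, Any]]) -> List[str]:
    """Return descriptions of suspicious records (an empty list means none were found)."""
    problems = []
    for rec in dataset_dicts:
        missing = [k for k in REQUIRED_KEYS if k not in rec]
        if missing:
            problems.append(f"{rec.get('image_id')}: missing {missing}")
            continue
        img_w, img_h = rec["width"], rec["height"]
        for i, ann in enumerate(rec["annotations"]):
            x, y, w, h = ann["bbox"]  # assumes XYWH_ABS, as in get_fiftyone_dicts
            if w <= 0 or h <= 0 or x < 0 or y < 0 or x + w > img_w or y + h > img_h:
                problems.append(
                    f"{rec['image_id']}[{i}]: bbox {ann['bbox']} invalid for {img_w}x{img_h} image"
                )
            for poly in ann.get("segmentation", []):
                if len(poly) < 6 or len(poly) % 2:
                    problems.append(f"{rec['image_id']}[{i}]: polygon with {len(poly)} coords")
    return problems
```

In the notebook, print(check_dataset_dicts(dataset_dicts)[:10]) would list the first few issues, if any, before training.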