Run in Google Colab
|
View source on GitHub
|
|
[1]:
import fiftyone as fo
import fiftyone.zoo as foz
from fiftyone.utils.torch import FiftyOneTorchDataset
[2]:
import utils
[3]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import torchvision.transforms.v2 as transforms
from torchvision import tv_tensors
import matplotlib.pyplot as plt
import matplotlib.patches as plt_patches
from PIL import Image
[4]:
torch.multiprocessing.set_start_method('forkserver')
Load Dataset#
[5]:
dataset = foz.load_zoo_dataset("quickstart", overwrite=True)
Overwriting existing directory '/home/jacobs/fiftyone/quickstart'
Downloading dataset to '/home/jacobs/fiftyone/quickstart'
Downloading dataset...
100% |████| 187.5Mb/187.5Mb [205.8ms elapsed, 0s remaining, 911.3Mb/s]
Extracting dataset...
Parsing dataset metadata
Found 200 samples
Dataset info written to '/home/jacobs/fiftyone/quickstart/info.json'
Loading existing dataset 'quickstart'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
[6]:
# make sure it's persistent
print(dataset.persistent)
# if it's not, set this variable to True
if not dataset.persistent:
dataset.persistent = True
True
Do Your Data Centric Work#
[7]:
some_interesting_view = dataset.take(100)
Transferring to a Torch Dataset and the get_item Argument#
In order to transfer to a torch dataset, we most provide a function that transforms FiftyOne Samples into input for the model. We pass this function with the get_item argument. Let’s start with a very simple example that will help us understand what’s happening.
[8]:
# to best understand what's happening, let's first pass the identity function
def get_item_identity(x):
return x
[9]:
torch_dataset = some_interesting_view.to_torch(get_item_identity)
[10]:
result = torch_dataset[0]
print(type(result))
print(result['id'])
print(result['filepath'])
<class 'fiftyone.core.sample.Sample'>
67be7705acd35912dc493b1e
/home/jacobs/fiftyone/quickstart/data/001312.jpg
The get_item function can be anything that accepts a fiftyone Sample. Here is a simple example:
[11]:
def simple_get_item(sample):
return sample['id']
[12]:
torch_datset = some_interesting_view.to_torch(simple_get_item)
[13]:
# torch_dataset is now a fully functional torch_dataset.
print(torch_datset[0])
67be7705acd35912dc493b1e
[14]:
# torch_dataset has the same exact samples as some_interesting_view, down to the order.
assert [res for res in torch_datset] == some_interesting_view.values('id')
Write an actual get_item function#
[15]:
# let's write a standard detection get_item
augmentations = transforms.Compose([
transforms.CenterCrop(512),
transforms.ClampBoundingBoxes()
])
def get_item(sample):
res = {}
image = Image.open(sample['filepath'])
og_wh = np.array([image.width, image.height])
image = tv_tensors.Image(image)
detections = sample['ground_truth.detections']
if detections is None:
detections = []
detections_tensor = torch.tensor([detection['bounding_box'] for detection in detections]) if len(detections)>0 \
else torch.zeros((0,4))
res['box'] = tv_tensors.BoundingBoxes(detections_tensor * torch.tensor([*og_wh, *og_wh]),
format=tv_tensors.BoundingBoxFormat('XYWH'),
canvas_size=image.shape[-2:]
)
res['label'] = [detection['label'] for detection in detections]
res['id'] = sample['id']
image, res = augmentations(image, res)
return image, res
Visualizing the result#
[16]:
# This is also a good opportunity to debug your get_item in case you need to
[17]:
torch_dataset = some_interesting_view.to_torch(get_item)
[18]:
# run this a couple of times to look through samples in the dataset
random_index = np.random.randint(0, len(torch_dataset))
image, res = torch_dataset[random_index]
plt.title(res['id'])
plt.imshow(image.permute(1, 2, 0).numpy())
axes = plt.gca()
for b, l in zip(res['box'], res['label']):
rect = plt_patches.Rectangle((b[0], b[1]),
b[2], b[3],
linewidth=1, edgecolor='r', facecolor='none')
axes.add_patch(rect)
axes.annotate(l, rect.get_xy())
plt.show()
Creating a DataLoader#
FiftyOneTorchDatasets are compatible with torch DataLoaders, and can be used during training. Here is how you can create a DataLoader:
[19]:
# We need a new dataset object here. Once we've already sampled from the previous one, we have opened up a DB connection
# making the object unpickleable, and not suitable for multiproccessing use.
# utils.get_item_quickstart is the same get_item as above.
torch_dataset = some_interesting_view.to_torch(utils.get_item_quickstart)
"""
The code we are running is as follows:
def simple_collate_fn(batch):
return tuple(zip(*batch))
def create_dataloader_simple(torch_dataset):
dataloader = torch.utils.data.DataLoader(torch_dataset,
batch_size=5,
shuffle=True,
num_workers=2, # we are compatible with many workers
worker_init_fn=FiftyOneTorchDataset.worker_init, # this is required for the dataloader to work
collate_fn=simple_collate_fn)
We are running it from a separate file because Jupyter Notebooks are not compatible with the 'spawn' and 'forkserver' start methods
for code that is written *in* the notebook.
"""
dataloader = utils.create_dataloader_simple(torch_dataset)
[20]:
"""
Code we are running:
def ids_in_dataloader(dataloader):
# we can iterate over the dataset like this:
ids_seen = []
for images, results in dataloader:
assert len(images) == 5 # we are actually getting a batch of 5
ids_seen += [results[i]['id'] for i in range(len(results))]
return ids_seen
"""
ids_seen = utils.ids_in_dataloader(dataloader)
[21]:
# confirm we have seen each sample once
from collections import Counter
ids_with_counts = Counter(ids_seen)
assert set(ids_with_counts.keys()) == set(some_interesting_view.values('id'))
assert np.all(np.array(list(ids_with_counts.values())) == 1)
[22]:
# visualizing results, run this a couple of times to see different batches
iterator = iter(dataloader)
[23]:
image, res = next(iterator)
fig, axes = plt.subplots(1, len(image), figsize=(4 * len(image), 3))
for i, img in enumerate(image):
axes[i].set_title(res[i]['id'])
axes[i].imshow(img.permute(1, 2, 0).numpy())
for b, l in zip(res[i]['box'], res[i]['label']):
rect = plt_patches.Rectangle((b[0], b[1]),
b[2], b[3],
linewidth=1, edgecolor='r', facecolor='none')
axes[i].add_patch(rect)
axes[i].annotate(l, rect.get_xy())
plt.show()
Run in Google Colab
View source on GitHub