Source2Tensors Demo

Preliminaries

Imports

[1]:
import functools
from pathlib import Path

import holoviews as hv
import panel as pn

from bridge.display.vision import Panel
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.data.uri_components import URIComponents
from bridge.utils import pmap

hv.extension("bokeh")
pn.extension()

TMP_NOTEBOOK_ROOT = Path("/tmp/bridge-ds/tutorials")

Load Dataset

[2]:
from bridge.providers.vision import Coco2017Detection

root_dir = TMP_NOTEBOOK_ROOT / "coco"

provider = Coco2017Detection(root_dir, split="val", img_source="download")
ds = provider.build_dataset()
ds
Annotations file /tmp/bridge-ds/tutorials/coco/annotations/instances_val2017.json already exists, skipping download.
Downloading images...
Downloading http://images.cocodataset.org/zips/val2017.zip to /tmp/bridge-ds/tutorials/coco/val2017.zip
Extracting /tmp/bridge-ds/tutorials/coco/val2017.zip to /tmp/bridge-ds/tutorials/coco
loading annotations into memory...
Done (t=0.92s)
creating index...
index created!
[2]:
Dataset: {'n_samples': 5000, 'n_bbox': 36781, 'n_image': 5000}

Demo: Data Processing - From Sources to PyTorch

In this demo, we’ll be working with COCO-val. We’ve already loaded it into a Bridge Dataset. Next, we’ll apply data augmentations and visualize the results. Once we’re satisfied with our augmentation pipeline, we’ll convert the augmented Dataset into a training-ready PyTorch dataset.

We want to apply data augmentations to our Dataset before feeding it to our model for training. For this, we have ds.transform_samples(), which accepts SampleTransform objects. One such SampleTransform is TorchvisionV2Transform, our adapter that lets users apply torchvision’s v2 transforms to a Dataset.

First, let’s define our transforms:

[3]:
from torchvision.transforms import v2

from bridge.primitives.sample.transform.vision import TorchvisionV2Transform

transforms = TorchvisionV2Transform(
    [
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomAffine(
            translate=(0.0625, 0.0625),
            scale=(0.9, 1.1),
            degrees=45,
        ),
        v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.01),
        v2.RandomResizedCrop((448, 448), scale=(0.01, 0.05)),
        v2.SanitizeBoundingBoxes(),
    ],
)

Let’s apply these transforms to our Dataset using transform_samples(). Note that transform_samples() follows the Sample API, not the Table API. Behind the scenes, it iterates over the samples one by one rather than using a vectorized pandas implementation.

[4]:
import random
import warnings

import numpy as np

random.seed(0)
np.random.seed(0)

# Cache the resulting augmented images into a local path ${TMP_NOTEBOOK_ROOT}/ds_augs
caches = {
    "image": CacheMechanism(URIComponents.from_str(str(TMP_NOTEBOOK_ROOT / "ds_augs"))),
}

# Function responsible for iterating and applying the SampleTransform.
# It could be as simple as `map`, but we can use a multi-process variant for better performance.
# map_fn = functools.partial(pmap, backend="concurrent", n_jobs=0, progress_bar=False)
map_fn = functools.partial(pmap, backend="dataloader", progress_bar=False)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)  # hide "low contrast" warnings
    ds_augs = ds.transform_samples(
        transform=transforms, map_fn=map_fn, cache_mechanisms=caches, display_engine=Panel(bbox_format="xywh")
    )
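
The map_fn argument determines how the per-sample loop is executed. The following is a minimal sketch of that pattern under stated assumptions. It is not Bridge’s implementation: transform_all, pool_map and negate_label are hypothetical names. For brevity, it uses a thread pool, whereas the notebook’s pmap(backend="dataloader") presumably relies on worker processes.

```python
import functools
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterable


def transform_all(samples: Iterable[Any], transform: Callable[[Any], Any], map_fn: Callable = map) -> list:
    """Hypothetical per-sample loop: apply `transform` to every sample via a pluggable `map_fn`."""
    return list(map_fn(transform, samples))


def pool_map(fn: Callable[[Any], Any], items: Iterable[Any], max_workers: int = 4) -> list:
    """Parallel drop-in for the builtin `map`; Executor.map returns results in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(fn, items))


def negate_label(sample: dict) -> dict:
    # toy per-sample transform; returns a new dict instead of mutating its input
    return {**sample, "label": -sample["label"]}


samples = [{"sample_id": i, "label": i} for i in range(8)]
sequential = transform_all(samples, negate_label)  # builtin map
parallel = transform_all(samples, negate_label, map_fn=functools.partial(pool_map, max_workers=2))
print(sequential == parallel)  # True
```

Because each sample is transformed independently, the choice of map_fn only affects speed, not the result. This is why the notebook can swap backends with a one-line functools.partial change.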

After a few seconds, we have our augmented dataset. By observing the samples table we can see that the new images were saved locally to our directory of choice:

[5]:
ds_augs.samples.head(3)
[5]:
| sample_id | element_id | element_type | data | category | license | file_name | coco_url | height | width | date_captured | flickr_url |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 139 | 139_img | image | /tmp/bridge-ds/tutorials/ds_augs/139_img.jpg | image | 2.0 | 000000000139.jpg | http://images.cocodataset.org/val2017/00000000... | 426.0 | 640.0 | 2013-11-21 01:34:01 | http://farm9.staticflickr.com/8035/8024364858_... |
| 285 | 285_img | image | /tmp/bridge-ds/tutorials/ds_augs/285_img.jpg | image | 4.0 | 000000000285.jpg | http://images.cocodataset.org/val2017/00000000... | 640.0 | 586.0 | 2013-11-18 13:09:47 | http://farm8.staticflickr.com/7434/9138147604_... |
| 632 | 632_img | image | /tmp/bridge-ds/tutorials/ds_augs/632_img.jpg | image | 3.0 | 000000000632.jpg | http://images.cocodataset.org/val2017/00000000... | 483.0 | 640.0 | 2013-11-20 21:14:01 | http://farm2.staticflickr.com/1241/1243324748_... |

And we can browse this augmented Dataset just like the original one:

[6]:
ds_augs.show()
[6]:

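By manually browsing our Dataset, we can see that we completely mis-parameterized the RandomResizedCrop augmentation: the crops are far too small. With scale=(0.01, 0.05), each crop covers only 1% to 5% of the image area before being resized to 448x448. Most objects therefore presumably fall outside the crop, and SanitizeBoundingBoxes then drops their degenerate boxes.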
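
To get a feel for those numbers, here is a rough back-of-the-envelope sketch. It assumes a square crop (RandomResizedCrop also samples an aspect ratio, 3/4 to 4/3 by default) and uses the original 640x426 size listed for sample 139. The helper square_crop_side is hypothetical. The scale 0.3 is included for comparison, since it is the lower bound we switch to below.

```python
import math


def square_crop_side(width: int, height: int, scale: float) -> float:
    """Side (px) of a square crop whose area is `scale` times the image area.

    Simplification: RandomResizedCrop also samples an aspect ratio (3/4 to 4/3 by default);
    here it is fixed at 1.
    """
    return math.sqrt(scale * width * height)


# 640x426: original width x height of sample 139 in the samples table
for scale in (0.01, 0.05, 0.3):
    print(f"scale={scale}: ~{square_crop_side(640, 426, scale):.0f} px square crop")
# scale=0.01: ~52 px square crop
# scale=0.05: ~117 px square crop
# scale=0.3: ~286 px square crop
```

With the original parameters, a patch of roughly 52 to 117 px is blown up to 448x448, so only the few objects inside that small window keep their boxes.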

We can confirm this by extracting statistics over the remaining annotations:

[7]:
print(f"num annotations ds: {len(ds.annotations)}")
print(f"num annotations ds_augs: {len(ds_augs.annotations)}")

n_annotations_per_image_ds = (
    ds.annotations.groupby("sample_id")
    .size()
    # samples with no annotations won't have a group in the groupby
    .reindex(ds.samples.index.get_level_values("sample_id"), fill_value=0)
    .mean()
)
n_annotations_per_image_ds_augs = (
    ds_augs.annotations.groupby("sample_id")
    .size()
    .reindex(ds_augs.samples.index.get_level_values("sample_id"), fill_value=0)
    .mean()
)

print(f"mean num annotations per image ds: {n_annotations_per_image_ds}")
print(f"mean num annotations per image ds_augs: {n_annotations_per_image_ds_augs}")
num annotations ds: 36781
num annotations ds_augs: 3034
mean num annotations per image ds: 7.3562
mean num annotations per image ds_augs: 0.6068
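
The reindex(..., fill_value=0) step in the cell above matters because a plain groupby only produces groups for samples that have at least one annotation. The sketch below mirrors that pandas chain in plain Python on made-up toy data. The function name is hypothetical. Including empty samples is also why the printed means equal the annotation totals divided by all 5,000 samples (36,781 / 5,000 = 7.3562).

```python
from collections import Counter
from typing import Hashable, Iterable, Sequence


def mean_annotations_per_sample(annotation_sample_ids: Iterable[Hashable], all_sample_ids: Sequence[Hashable]) -> float:
    """Mean annotations per sample, counting samples without annotations as 0.

    Mirrors groupby("sample_id").size().reindex(all_ids, fill_value=0).mean().
    """
    counts = Counter(annotation_sample_ids)  # like groupby().size(): only ids with >= 1 annotation
    per_sample = [counts.get(sid, 0) for sid in all_sample_ids]  # the reindex(fill_value=0) step
    return sum(per_sample) / len(per_sample)


# toy data (made up): sample 632 has no annotations at all
annotation_sample_ids = [139, 139, 139, 285]
all_sample_ids = [139, 285, 632]

counts = Counter(annotation_sample_ids)
print(sum(counts.values()) / len(counts))  # 2.0, silently skips sample 632
print(mean_annotations_per_sample(annotation_sample_ids, all_sample_ids))  # 1.3333333333333333
```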

The numbers tell the same story: only 3,034 of the original 36,781 annotations survived, about 8%. Let’s fix the transform parameters and reapply them:

[8]:
transforms = TorchvisionV2Transform(
    [
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomAffine(
            translate=(0.0625, 0.0625),
            scale=(0.9, 1.1),
            degrees=45,
        ),
        v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.01),
        v2.RandomResizedCrop((448, 448), scale=(0.3, 1.0)),
        v2.SanitizeBoundingBoxes(),
    ]
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)  # hide "low contrast" warnings
    ds_augs = ds.transform_samples(
        transform=transforms, map_fn=map_fn, cache_mechanisms=caches, display_engine=Panel(bbox_format="xywh")
    )
[9]:
print(f"num annotations ds: {len(ds.annotations)}")
print(f"num annotations ds_augs: {len(ds_augs.annotations)}")

n_annotations_per_image_ds = (
    ds.annotations.groupby("sample_id")
    .size()
    .reindex(ds.samples.index.get_level_values("sample_id"), fill_value=0)
    .mean()
)
n_annotations_per_image_ds_augs = (
    ds_augs.annotations.groupby("sample_id")
    .size()
    .reindex(ds_augs.samples.index.get_level_values("sample_id"), fill_value=0)
    .mean()
)

print(f"mean num annotations per image ds: {n_annotations_per_image_ds}")
print(f"mean num annotations per image ds_augs: {n_annotations_per_image_ds_augs}")
num annotations ds: 36781
num annotations ds_augs: 22723
mean num annotations per image ds: 7.3562
mean num annotations per image ds_augs: 4.5446

This time, far fewer annotations were lost to the random crop: 22,723 of 36,781 (about 62%) survived, up from about 8%. We can inspect the samples manually as well, if we’d like:

[10]:
ds_augs.show()
[10]:

At this point, we’re satisfied with our augmented Dataset. The next step is converting it into viable input for a deep learning model, that is, converting the dataset to tensors. We’ll demonstrate with PyTorch as our engine of choice, but this technique should generalize to other deep learning frameworks as well.

NOTE: apart from using torchvision for the augmentations themselves, nothing up to this point ties us to a particular training framework. The augmented images are cached to disk, so the same Dataset could just as well feed a Keras or TensorFlow model.

The conversion to tensors works exactly as before, with transform_samples():

[11]:
import warnings

to_tensor_transform = TorchvisionV2Transform(
    [
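        v2.RGB(),
        # v2.ToTensor() is deprecated; the warning below suggests
        # v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]) instead.
        # It is raised when ToTensor() is constructed, i.e. outside the catch_warnings block.
        v2.ToTensor(),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


with warnings.catch_warnings():
    # silence warnings raised while transforming the samples
    warnings.filterwarnings("ignore")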
    ds_tensors = ds_augs.transform_samples(
        transform=to_tensor_transform,
        map_fn=map_fn,
        display_engine=None,  # the output is not images anymore, so a Panel DisplayEngine won't work
        cache_mechanisms={"image": CacheMechanism(URIComponents.from_str(str(TMP_NOTEBOOK_ROOT / "ds_tensors")))},
    )
/home/docs/checkouts/readthedocs.org/user_builds/bridge-ds/envs/latest/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py:42: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.Output is equivalent up to float precision.
  warnings.warn(

Since we can’t use Panel rendering anymore, let’s just use a few prints to make sure the data is in our required format:

[12]:
img_data = ds_tensors.iget(0).data
print("shape:", img_data.shape, "\n")
print(img_data)
shape: torch.Size([3, 448, 448])

tensor([[[ 0.8104,  0.8104,  0.8104,  ..., -2.1179, -2.1179, -2.1179],
         [ 0.8104,  0.8104,  0.8104,  ..., -2.1179, -2.1179, -2.1179],
         [ 0.7762,  0.7762,  0.7762,  ..., -2.1179, -2.1179, -2.1179],
         ...,
         [ 0.6049,  0.5022,  0.4508,  ..., -2.1179, -2.1179, -2.1179],
         [ 0.6049,  0.4508,  0.3994,  ..., -2.1179, -2.1179, -2.1179],
         [ 0.5878,  0.4337,  0.3309,  ..., -2.1179, -2.1179, -2.1179]],

        [[ 0.1527,  0.1527,  0.1527,  ..., -2.0357, -2.0357, -2.0357],
         [ 0.1527,  0.1527,  0.1527,  ..., -2.0357, -2.0357, -2.0357],
         [ 0.1702,  0.1702,  0.1702,  ..., -2.0357, -2.0357, -2.0357],
         ...,
         [ 0.5378,  0.4328,  0.3627,  ..., -2.0357, -2.0357, -2.0357],
         [ 0.5203,  0.3627,  0.2402,  ..., -2.0357, -2.0357, -2.0357],
         [ 0.5028,  0.3102,  0.1702,  ..., -2.0357, -2.0357, -2.0357]],

        [[-0.9853, -0.9853, -0.9853,  ..., -1.8044, -1.8044, -1.8044],
         [-0.9853, -0.9853, -0.9853,  ..., -1.8044, -1.8044, -1.8044],
         [-0.9853, -0.9853, -0.9853,  ..., -1.8044, -1.8044, -1.8044],
         ...,
         [ 0.0256, -0.0790, -0.1835,  ..., -1.8044, -1.8044, -1.8044],
         [-0.0267, -0.1835, -0.2881,  ..., -1.8044, -1.8044, -1.8044],
         [-0.0441, -0.2184, -0.3578,  ..., -1.8044, -1.8044, -1.8044]]])
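
The constant values in the right-hand columns of each channel (-2.1179, -2.0357 and -1.8044) are exactly what a black pixel becomes after Normalize. This suggests those pixels are zero-valued padding, for instance the default fill=0 that RandomAffine uses for regions moved in from outside the image. The small sketch below works through Normalize’s per-channel formula, (x - mean) / std, with the ImageNet statistics from the cell above. The helper names are hypothetical.

```python
# Per-channel ImageNet statistics passed to v2.Normalize in the cell above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(rgb):
    """Normalize one pixel: (x - mean) / std per channel, with x in [0, 1] as produced by ToTensor."""
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


def denormalize(values):
    """Invert normalize: x = value * std + mean, e.g. to view a normalized tensor as an image."""
    return tuple(v * s + m for v, m, s in zip(values, MEAN, STD))


print([round(v, 4) for v in normalize((0.0, 0.0, 0.0))])  # [-2.1179, -2.0357, -1.8044]
# first red-channel value in the output above, mapped back to the 0-255 scale
print(round(denormalize((0.8104, 0.0, 0.0))[0] * 255))  # 171
```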

The last step is to convert ds_tensors into a torch Dataset. We’ll do this using the PytorchEngineDataset class, which inherits directly from torch.utils.data.Dataset:

[13]:
import torch

from bridge.engines.pytorch import PytorchEngineDataset

ds_pytorch = PytorchEngineDataset(ds_tensors)

print(isinstance(ds_pytorch, torch.utils.data.Dataset))
print(type(ds_pytorch))
print(len(ds_pytorch))
True
<class 'bridge.engines.pytorch.PytorchEngineDataset'>
5000
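
A map-style torch Dataset only needs __getitem__ (and usually __len__). The pure-Python sketch below shows one way such a class could group a sample’s elements by etype into a dict of lists. It mirrors the item layout we print next, but it is a hypothetical illustration, not PytorchEngineDataset’s actual implementation. EtypeDictDataset and its (etype, data) input format are assumptions, and a real version would subclass torch.utils.data.Dataset.

```python
from collections import defaultdict
from typing import Any, Sequence


class EtypeDictDataset:
    """Hypothetical map-style dataset: item i maps each etype to a list of element data."""

    def __init__(self, samples: Sequence[Sequence[tuple[str, Any]]]):
        # each sample is a sequence of (etype, data) pairs, one pair per element
        self._samples = samples

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> dict[str, list[Any]]:
        grouped: dict[str, list[Any]] = defaultdict(list)
        for etype, data in self._samples[idx]:
            grouped[etype].append(data)
        return dict(grouped)


toy_ds = EtypeDictDataset([[("image", "img-139"), ("bbox", "box-a"), ("bbox", "box-b")]])
print(toy_ds[0])  # {'image': ['img-139'], 'bbox': ['box-a', 'box-b']}
```

With this layout, even a single image sits inside a list. That is why the next cell reads it as item["image"][0].
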
[14]:
item = ds_pytorch[0]

img = item["image"][0]
bboxes = item["bbox"]
print("Image: ")
print(img, img.shape)
print()
print("Bbox Classes: ")
print([bbox.class_label for bbox in bboxes])
print()
print("Bbox Coords: ")
print([bbox.coords for bbox in bboxes])
Image:
tensor([[[ 0.8104,  0.8104,  0.8104,  ..., -2.1179, -2.1179, -2.1179],
         [ 0.8104,  0.8104,  0.8104,  ..., -2.1179, -2.1179, -2.1179],
         [ 0.7762,  0.7762,  0.7762,  ..., -2.1179, -2.1179, -2.1179],
         ...,
         [ 0.6049,  0.5022,  0.4508,  ..., -2.1179, -2.1179, -2.1179],
         [ 0.6049,  0.4508,  0.3994,  ..., -2.1179, -2.1179, -2.1179],
         [ 0.5878,  0.4337,  0.3309,  ..., -2.1179, -2.1179, -2.1179]],

        [[ 0.1527,  0.1527,  0.1527,  ..., -2.0357, -2.0357, -2.0357],
         [ 0.1527,  0.1527,  0.1527,  ..., -2.0357, -2.0357, -2.0357],
         [ 0.1702,  0.1702,  0.1702,  ..., -2.0357, -2.0357, -2.0357],
         ...,
         [ 0.5378,  0.4328,  0.3627,  ..., -2.0357, -2.0357, -2.0357],
         [ 0.5203,  0.3627,  0.2402,  ..., -2.0357, -2.0357, -2.0357],
         [ 0.5028,  0.3102,  0.1702,  ..., -2.0357, -2.0357, -2.0357]],

        [[-0.9853, -0.9853, -0.9853,  ..., -1.8044, -1.8044, -1.8044],
         [-0.9853, -0.9853, -0.9853,  ..., -1.8044, -1.8044, -1.8044],
         [-0.9853, -0.9853, -0.9853,  ..., -1.8044, -1.8044, -1.8044],
         ...,
         [ 0.0256, -0.0790, -0.1835,  ..., -1.8044, -1.8044, -1.8044],
         [-0.0267, -0.1835, -0.2881,  ..., -1.8044, -1.8044, -1.8044],
         [-0.0441, -0.2184, -0.3578,  ..., -1.8044, -1.8044, -1.8044]]]) torch.Size([3, 448, 448])

Bbox Classes:
[ClassLabel(class_idx=64, class_name='64'), ClassLabel(class_idx=72, class_name='72'), ClassLabel(class_idx=72, class_name='72'), ClassLabel(class_idx=62, class_name='62'), ClassLabel(class_idx=62, class_name='62'), ClassLabel(class_idx=62, class_name='62'), ClassLabel(class_idx=62, class_name='62'), ClassLabel(class_idx=1, class_name='1'), ClassLabel(class_idx=1, class_name='1'), ClassLabel(class_idx=78, class_name='78'), ClassLabel(class_idx=82, class_name='82'), ClassLabel(class_idx=84, class_name='84'), ClassLabel(class_idx=84, class_name='84'), ClassLabel(class_idx=85, class_name='85'), ClassLabel(class_idx=86, class_name='86'), ClassLabel(class_idx=86, class_name='86'), ClassLabel(class_idx=62, class_name='62')]

Bbox Coords:
[array([331.7733872 ,  99.10415502, 108.91498417, 126.12165371]), array([104.755404  , 161.9056747 ,  84.27551929, 153.33701666]), array([ 25.95473018, 175.99640154,  90.17015774, 148.70905608]), array([168.26144467, 162.25379621,  50.36785358, 118.11670962]), array([ 57.00911532, 180.76176506,  27.11477246,  20.74382798]), array([155.29343776,  67.70910096,  87.59939243, 200.7845585 ]), array([125.35288731, 102.70737313,  24.27962839,  52.30976759]), array([279.20733602, 119.69791172,  20.06056088,  25.18386227]), array([251.14064792,  79.81411973,  44.12225454, 152.73697367]), array([405.13882289, 235.77151534,  25.28625761,  65.76908602]), array([415.36066702, 237.39654691,  23.74134854,  66.43655065]), array([188.66284214,  18.73281143,  20.30233055,  33.09251707]), array([341.53912373, 248.17790794,  59.46803984, 130.95819458]), array([ 93.51015928, 161.28961548,  17.43064717,  33.41949108]), array([166.40234985, 161.66819297,  13.50503784,  19.3068486 ]), array([ 75.60586367, 152.05378451,  14.42549697,  25.08760577]), array([ 63.69593355, 172.71200281, 161.87381353, 150.09436995])]

As we can see, every item in PytorchEngineDataset is a dictionary whose string keys match the etypes (in our case, ‘image’ and ‘bbox’). The values are lists of element data: each image is a torch.Tensor, and each bbox is an instance of a class we created, exposing class_label and coords, though you can use whatever representation you like.
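
If you go on to batch ds_pytorch with torch.utils.data.DataLoader, the default collate function is likely to reject these items. The bbox lists differ in length from sample to sample and hold custom objects. One option is to pass a custom collate_fn. The sketch below is hypothetical (collate_by_etype is our own name) and simply keeps everything as Python lists. Since every image here is 3x448x448, stacking the images with torch.stack would be a natural next step.

```python
from typing import Any


def collate_by_etype(batch: list[dict[str, list[Any]]]) -> dict[str, list[list[Any]]]:
    """Hypothetical collate_fn: for each etype, gather the per-sample lists into one list.

    Samples that lack an etype (e.g. an image with no boxes) contribute an empty list,
    so every value in the result has exactly one entry per sample in the batch.
    """
    etypes = dict.fromkeys(etype for item in batch for etype in item)  # ordered union of keys
    return {etype: [item.get(etype, []) for item in batch] for etype in etypes}


# Intended use (not executed here):
# loader = torch.utils.data.DataLoader(ds_pytorch, batch_size=8, collate_fn=collate_by_etype)
```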