{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# Preliminaries" ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": {}, "outputs": [], "source": [ "import functools\n", "from pathlib import Path\n", "\n", "import holoviews as hv\n", "import panel as pn\n", "\n", "from bridge.display.vision import Panel\n", "from bridge.primitives.element.data.cache_mechanism import CacheMechanism\n", "from bridge.primitives.element.data.uri_components import URIComponents\n", "from bridge.utils import pmap\n", "\n", "hv.extension(\"bokeh\")\n", "pn.extension()\n", "\n", "TMP_NOTEBOOK_ROOT = Path(\"/tmp/bridge-ds/tutorials\")" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "## Load Dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [], "source": [ "from bridge.providers.vision import Coco2017Detection\n", "\n", "root_dir = TMP_NOTEBOOK_ROOT / \"coco\"\n", "\n", "provider = Coco2017Detection(root_dir, split=\"val\", img_source=\"download\")\n", "ds = provider.build_dataset()\n", "ds" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "# Demo: Data Processing - From Sources to Pytorch\n", "\n", "In this demo, we'll be working with COCO-val. We began by loading it into Bridge Dataset, and we will proceed by applying data augmentations, visualizing the results, and once we're satisfied with our augmentation pipeline we will finally convert this augmented Dataset into a training-ready PyTorch dataset." ] }, { "cell_type": "markdown", "id": "6", "metadata": {}, "source": [ "### Applying Data Augmentations\n", "We want to apply data augmentations on our Dataset before feeding it to our model for training.\n", "For this purpose, we have `ds.transform_samples()` which accepts **SampleTransform** objects. 
One such SampleTransform is **TorchvisionV2Transform**, an adapter that lets users apply [torchvision's v2 transforms](https://docs.pytorch.org/vision/stable/transforms.html#v1-or-v2-which-one-should-i-use) to a Dataset.\n", "\n", "First, let's define our transforms:" ] }, { "cell_type": "code", "execution_count": null, "id": "7", "metadata": {}, "outputs": [], "source": [ "from torchvision.transforms import v2\n", "\n", "from bridge.primitives.sample.transform.vision import TorchvisionV2Transform\n", "\n", "transforms = TorchvisionV2Transform(\n", " [\n", " v2.RandomHorizontalFlip(p=0.5),\n", " v2.RandomAffine(\n", " translate=(0.0625, 0.0625),\n", " scale=(0.9, 1.1),\n", " degrees=45,\n", " ),\n", " v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.01),\n", " v2.RandomResizedCrop((448, 448), scale=(0.01, 0.05)),\n", " v2.SanitizeBoundingBoxes(),\n", " ],\n", ")" ] }, { "cell_type": "markdown", "id": "8", "metadata": {}, "source": [ "Let's apply these transforms to our Dataset using `transform_samples()`. Note that `transform_samples()` adheres to the Sample API, _not_ the Table API. This means that behind the scenes it iterates over the samples one by one, rather than using a vectorized pandas implementation."
] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": {}, "outputs": [], "source": [ "import random\n", "import warnings\n", "\n", "import numpy as np\n", "\n", "random.seed(0)\n", "np.random.seed(0)\n", "\n", "# Cache the resulting augmented images into a local path ${TMP_NOTEBOOK_ROOT}/ds_augs\n", "caches = {\n", " \"image\": CacheMechanism(URIComponents.from_str(str(TMP_NOTEBOOK_ROOT / \"ds_augs\"))),\n", "}\n", "\n", "# Function responsible for iterating and applying the SampleTransform.\n", "# It could be as simple as `map`, but we can use a multi-process variant for better performance.\n", "# map_fn = functools.partial(pmap, backend=\"concurrent\",n_jobs=0, progress_bar=False)\n", "map_fn = functools.partial(pmap, backend=\"dataloader\", progress_bar=False)\n", "\n", "with warnings.catch_warnings():\n", " warnings.simplefilter(\"ignore\", category=UserWarning) # hide \"low contrast\" warnings\n", " ds_augs = ds.transform_samples(\n", " transform=transforms, map_fn=map_fn, cache_mechanisms=caches, display_engine=Panel(bbox_format=\"xywh\")\n", " )" ] }, { "cell_type": "markdown", "id": "10", "metadata": {}, "source": [ "After a few seconds, we have our augmented dataset. 
By observing the `samples` table we can see that the new images were saved locally to our directory of choice:" ] }, { "cell_type": "code", "execution_count": null, "id": "11", "metadata": {}, "outputs": [], "source": [ "ds_augs.samples.head(3)" ] }, { "cell_type": "markdown", "id": "12", "metadata": {}, "source": [ "And we can browse this augmented Dataset just like the original one:" ] }, { "cell_type": "code", "execution_count": null, "id": "13", "metadata": {}, "outputs": [], "source": [ "ds_augs.show()" ] }, { "cell_type": "markdown", "id": "14", "metadata": {}, "source": [ "By manually browsing our Dataset, we can see that we completely mis-parameterized the `RandomResizedCrop` augmentation - the crops are too small! A `scale` of `(0.01, 0.05)` keeps only 1% to 5% of the image area.\n", "\n", "We can confirm this by extracting statistics over the remaining annotations:" ] }, { "cell_type": "code", "execution_count": null, "id": "15", "metadata": {}, "outputs": [], "source": [ "print(f\"num annotations ds: {len(ds.annotations)}\")\n", "print(f\"num annotations ds_augs: {len(ds_augs.annotations)}\")\n", "\n", "n_annotations_per_image_ds = (\n", " ds.annotations.groupby(\"sample_id\")\n", " .size()\n", " # samples with no annotations won't have a group in the groupby\n", " .reindex(ds.samples.index.get_level_values(\"sample_id\"), fill_value=0)\n", " .mean()\n", ")\n", "n_annotations_per_image_ds_augs = (\n", " ds_augs.annotations.groupby(\"sample_id\")\n", " .size()\n", " .reindex(ds_augs.samples.index.get_level_values(\"sample_id\"), fill_value=0)\n", " .mean()\n", ")\n", "\n", "print(f\"mean num annotations per image ds: {n_annotations_per_image_ds}\")\n", "print(f\"mean num annotations per image ds_augs: {n_annotations_per_image_ds_augs}\")" ] }, { "cell_type": "markdown", "id": "16", "metadata": {}, "source": [ "We can see that the numbers tell the same story - we've lost many annotations. 
Let's fix the transform parameters by widening the crop `scale` range to `(0.3, 1.0)`, and reapply them:" ] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "transforms = TorchvisionV2Transform(\n", " [\n", " v2.RandomHorizontalFlip(p=0.5),\n", " v2.RandomAffine(\n", " translate=(0.0625, 0.0625),\n", " scale=(0.9, 1.1),\n", " degrees=45,\n", " ),\n", " v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.01),\n", " v2.RandomResizedCrop((448, 448), scale=(0.3, 1.0)),\n", " v2.SanitizeBoundingBoxes(),\n", " ]\n", ")\n", "\n", "with warnings.catch_warnings():\n", " warnings.simplefilter(\"ignore\", category=UserWarning) # hide \"low contrast\" warnings\n", " ds_augs = ds.transform_samples(\n", " transform=transforms, map_fn=map_fn, cache_mechanisms=caches, display_engine=Panel(bbox_format=\"xywh\")\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "print(f\"num annotations ds: {len(ds.annotations)}\")\n", "print(f\"num annotations ds_augs: {len(ds_augs.annotations)}\")\n", "\n", "n_annotations_per_image_ds = (\n", " ds.annotations.groupby(\"sample_id\")\n", " .size()\n", " .reindex(ds.samples.index.get_level_values(\"sample_id\"), fill_value=0)\n", " .mean()\n", ")\n", "n_annotations_per_image_ds_augs = (\n", " ds_augs.annotations.groupby(\"sample_id\")\n", " .size()\n", " .reindex(ds_augs.samples.index.get_level_values(\"sample_id\"), fill_value=0)\n", " .mean()\n", ")\n", "\n", "print(f\"mean num annotations per image ds: {n_annotations_per_image_ds}\")\n", "print(f\"mean num annotations per image ds_augs: {n_annotations_per_image_ds_augs}\")" ] }, { "cell_type": "markdown", "id": "19", "metadata": {}, "source": [ "This time, we've lost significantly fewer annotations to the random crop operation. 
We can inspect the samples manually as well, if we'd like:" ] }, { "cell_type": "code", "execution_count": null, "id": "20", "metadata": {}, "outputs": [], "source": [ "ds_augs.show()" ] }, { "cell_type": "markdown", "id": "21", "metadata": {}, "source": [ "### Converting to tensors\n", "At this point, we're satisfied with our augmented Dataset. The next step is converting this dataset into viable input for a deep learning model - that is, converting the dataset to tensors. We'll demonstrate with PyTorch as our engine of choice, but this technique should generalize to other deep learning frameworks just as well.\n", "\n", "NOTE: up to this point in the tutorial, nothing has depended on which deep learning framework we use. All of this works just as well if our DL framework of choice is Keras or TensorFlow.\n", "\n", "The transformation into tensors works exactly as before, with `transform_samples()`:" ] }, { "cell_type": "code", "execution_count": null, "id": "22", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "to_tensor_transform = TorchvisionV2Transform(\n", " [\n", " v2.RGB(),\n", " v2.ToTensor(),\n", " v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", " ]\n", ")\n", "\n", "\n", "with warnings.catch_warnings():\n", " warnings.filterwarnings(\n", " \"ignore\"\n", " ) # Applying v2.RGB() on an image that is already RGB throws a warning, we'll filter these out\n", " ds_tensors = ds_augs.transform_samples(\n", " transform=to_tensor_transform,\n", " map_fn=map_fn,\n", " display_engine=None, # the output is not images anymore, so a Panel DisplayEngine won't work\n", " cache_mechanisms={\"image\": CacheMechanism(URIComponents.from_str(str(TMP_NOTEBOOK_ROOT / \"ds_tensors\")))},\n", " )" ] }, { "cell_type": "markdown", "id": "23", "metadata": {}, "source": [ "Since we can't use `Panel` rendering anymore, let's just use a few prints to make sure the data is in our required format:" ] }, { "cell_type": 
"code", "execution_count": null, "id": "24", "metadata": {}, "outputs": [], "source": [ "img_data = ds_tensors.iget(0).data\n", "print(\"shape:\", img_data.shape, \"\\n\")\n", "print(img_data)" ] }, { "cell_type": "markdown", "id": "25", "metadata": {}, "source": [ "The last step is to convert `ds_tensors` to a torch Dataset. We will do this using a `PytorchEngineDataset` object, which directly inherits from `torch.utils.data.Dataset`:" ] }, { "cell_type": "code", "execution_count": null, "id": "26", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from bridge.engines.pytorch import PytorchEngineDataset\n", "\n", "ds_pytorch = PytorchEngineDataset(ds_tensors)\n", "\n", "print(isinstance(ds_pytorch, torch.utils.data.Dataset))\n", "print(type(ds_pytorch))\n", "print(len(ds_pytorch))" ] }, { "cell_type": "code", "execution_count": null, "id": "27", "metadata": {}, "outputs": [], "source": [ "item = ds_pytorch[0]\n", "\n", "img = item[\"image\"][0]\n", "bboxes = item[\"bbox\"]\n", "print(\"Image: \")\n", "print(img, img.shape)\n", "print()\n", "print(\"Bbox Classes: \")\n", "print([bbox.class_label for bbox in bboxes])\n", "print()\n", "print(\"Bbox Coords: \")\n", "print([bbox.coords for bbox in bboxes])" ] }, { "cell_type": "markdown", "id": "28", "metadata": {}, "source": [ "As we can see, every item in `PytorchEngineDataset` is a dictionary with string keys that match etypes (in our case, 'image' and 'bbox'). The values are lists of objects: each image is a `torch.Tensor`, and each bbox is an instance of a class we created, but you can use whatever you like." ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 5 }