From fb282fefd61644a8264496121a79840770987c20 Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Wed, 18 Sep 2024 17:01:26 +0200 Subject: [PATCH] automatically shuffle every __iter__. Add tutorial notebook. --- docs/tutorials/visualizing_samples.ipynb | 374 +++++++++++++++++++++++ torchgeo/samplers/single.py | 18 +- 2 files changed, 388 insertions(+), 4 deletions(-) create mode 100644 docs/tutorials/visualizing_samples.ipynb diff --git a/docs/tutorials/visualizing_samples.ipynb b/docs/tutorials/visualizing_samples.ipynb new file mode 100644 index 00000000000..80be70dc7a5 --- /dev/null +++ b/docs/tutorials/visualizing_samples.ipynb @@ -0,0 +1,374 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualizing Samples\n", + "\n", + "This tutorial shows how to visualize and save the extent of your samples before and during training. In this particular example, we compare a vanilla RandomGeoSampler with one bounded by multiple ROIs and show how easy it is to gain insight into the distribution of your samples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from torchgeo.datasets import NAIP, stack_samples\n", + "from torchgeo.datasets.utils import download_url\n", + "from torchgeo.samplers import RandomGeoSampler" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def run_epochs(dataset, sampler):\n", + " dataloader = DataLoader(\n", + " naip, sampler=sampler, batch_size=1, collate_fn=stack_samples, num_workers=0\n", + " )\n", + " fig, ax = plt.subplots()\n", + " num_epochs = 5\n", + " for epoch in range(num_epochs):\n", + " color = plt.cm.viridis(epoch / num_epochs)\n", + " sampler.chips.to_file(f'naip_chips_epoch_{epoch}')\n", + " ax = sampler.chips.plot(ax=ax, color=color)\n", + " for sample in dataloader:\n", + " pass\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using downloaded and verified file: C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\naip\\m_3807511_ne_18_060_20181104.tif\n", + "Using downloaded and verified file: C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\naip\\m_3807512_sw_18_060_20180815.tif\n" + ] + } + ], + "source": [ + "naip_root = os.path.join(tempfile.gettempdir(), 'naip')\n", + "naip_url = (\n", + " 'https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/'\n", + ")\n", + "tiles = ['m_3807511_ne_18_060_20181104.tif', 'm_3807512_sw_18_060_20180815.tif']\n", + "for tile in tiles:\n", + " download_url(naip_url + tile, naip_root)\n", + "\n", + "naip = NAIP(naip_root)" + ] + }, + { + "cell_type": "markdown", + "metadata": 
{}, + "source": [ + "First, we create the default sampler for our dataset (3 samples), run it for 5 epochs, and plot its results. Each color displays a different epoch, so we can see how the RandomGeoSampler has distributed its samples for every epoch." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "generating samples... \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3/3 [00:00<00:00, 823.92it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "sampler = RandomGeoSampler(naip, size=1000, length=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "generating samples... \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3/3 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "run_epochs(naip, sampler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we split our dataset by two bounding boxes and re-inspect the samples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from torchgeo.datasets import roi_split\n", + "from torchgeo.datasets.utils import BoundingBox\n", + "\n", + "rois = [\n", + " BoundingBox(440854, 442938, 4299766, 4301731, 0, np.inf),\n", + " BoundingBox(449070, 451194, 4289463, 4291746, 0, np.inf),\n", + "]\n", + "datasets = roi_split(naip, rois)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "combined = datasets[0] | datasets[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "generating samples... \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3/3 [00:00<00:00, 2997.36it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "generating samples... 
\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3/3 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sampler = RandomGeoSampler(combined, size=1000, length=3)\n", + "run_epochs(combined, sampler)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cca", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 05bdf32f314..fcb4ed536f8 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -269,6 +269,18 @@ def __init__( self.chips = self.get_chips() + def __iter__(self) -> Iterator[BoundingBox]: + """Return the index of a dataset. + + Returns: + (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset + """ + self.refresh_samples() + for _, chip in self.chips.iterrows(): + yield BoundingBox( + chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt + ) + def refresh_samples(self) -> None: """Refresh the samples in the sampler. @@ -284,6 +296,7 @@ def get_chips(self) -> GeoDataFrame: A GeoDataFrame containing the generated chips. """ chips = [] + print('generating samples... ') for _ in tqdm(range(self.length)): # Choose a random tile, weighted by area idx = torch.multinomial(self.areas, 1) @@ -305,7 +318,6 @@ def get_chips(self) -> GeoDataFrame: chips.append(chip) if chips: - print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) chips_gdf['fid'] = chips_gdf.index @@ -412,7 +424,6 @@ def get_chips(self) -> GeoDataFrame: chips.append(chip) if chips: - print('creating geodataframe... 
') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) chips_gdf['fid'] = chips_gdf.index @@ -468,6 +479,7 @@ def get_chips(self) -> GeoDataFrame: if self.shuffle: generator = torch.randperm + print('generating samples... ') chips = [] for idx in generator(self.length): minx, maxx, miny, maxy, mint, maxt = self.hits[idx].bounds @@ -480,11 +492,9 @@ def get_chips(self) -> GeoDataFrame: 'mint': mint, 'maxt': maxt, } - print('generating chip') self.length += 1 chips.append(chip) - print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) chips_gdf['fid'] = chips_gdf.index return chips_gdf