diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000000..d969f962b02 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/tests/data/samplers/filtering_4x4.feather b/tests/data/samplers/filtering_4x4.feather new file mode 100644 index 00000000000..305d37e4fa6 Binary files /dev/null and b/tests/data/samplers/filtering_4x4.feather differ diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.cpg b/tests/data/samplers/filtering_4x4/filtering_4x4.cpg new file mode 100644 index 00000000000..57decb48120 --- /dev/null +++ b/tests/data/samplers/filtering_4x4/filtering_4x4.cpg @@ -0,0 +1 @@ +ISO-8859-1 diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.dbf b/tests/data/samplers/filtering_4x4/filtering_4x4.dbf new file mode 100644 index 00000000000..499d67bcec4 Binary files /dev/null and b/tests/data/samplers/filtering_4x4/filtering_4x4.dbf differ diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.prj b/tests/data/samplers/filtering_4x4/filtering_4x4.prj new file mode 100644 index 00000000000..42fd4b91b78 --- /dev/null +++ b/tests/data/samplers/filtering_4x4/filtering_4x4.prj @@ -0,0 +1 @@ +PROJCS["NAD_1983_BC_Environment_Albers",GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers"],PARAMETER["False_Easting",1000000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-126.0],PARAMETER["Standard_Parallel_1",50.0],PARAMETER["Standard_Parallel_2",58.5],PARAMETER["Latitude_Of_Origin",45.0],UNIT["Meter",1.0]] diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.shp b/tests/data/samplers/filtering_4x4/filtering_4x4.shp new file mode 100644 index 00000000000..65606c26dd6 Binary files /dev/null and b/tests/data/samplers/filtering_4x4/filtering_4x4.shp differ diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.shx b/tests/data/samplers/filtering_4x4/filtering_4x4.shx new file mode 100644 index 00000000000..b2028e759e5 Binary files /dev/null and b/tests/data/samplers/filtering_4x4/filtering_4x4.shx differ diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 1416368098a..7cf54f69000 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -2,12 +2,15 @@ # Licensed under the MIT License. import math -from collections.abc import Iterator +import os from itertools import product +import geopandas as gpd import pytest from _pytest.fixtures import SubRequest +from geopandas import GeoDataFrame from rasterio.crs import CRS +from shapely.geometry import box from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples @@ -23,11 +26,23 @@ class CustomGeoSampler(GeoSampler): def __init__(self) -> None: - pass + self.chips = self.get_chips() - def __iter__(self) -> Iterator[BoundingBox]: + def get_chips(self) -> GeoDataFrame: + chips = [] for i in range(len(self)): - yield BoundingBox(i, i, i, i, i, i) + chips.append( + { + 'geometry': box(i, i, i, i), + 'minx': i, + 'miny': i, + 'maxx': i, + 'maxy': i, + 'mint': i, + 'maxt': i, + } + ) + return GeoDataFrame(chips, crs='3005') def __len__(self) -> int: return 2 @@ -43,6 +58,17 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: return {'index': query} +class CustomGeoDatasetSITS(GeoDataset): + def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: + super().__init__() + self._crs = crs + self.res = res + self.return_as_ts = True + + def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: + return {'index': query} + + class TestGeoSampler: @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: @@ -54,6 +80,14 @@ def dataset(self) -> CustomGeoDataset: def sampler(self) -> CustomGeoSampler: return CustomGeoSampler() + @pytest.fixture(scope='class') + def datadir(self) -> str: + return os.path.join('tests', 'data', 'samplers') + + def test_no_get_chips_implemented(self, dataset: CustomGeoDataset) -> None: + with pytest.raises(TypeError): + GeoSampler(dataset) + def test_iter(self, sampler: CustomGeoSampler) -> None: assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0) @@ -64,6 +98,62 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): GeoSampler(dataset) # type: ignore[abstract] + @pytest.mark.parametrize( + 'filtering_file', ['filtering_4x4', 'filtering_4x4.feather'] + ) + def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + iterator = iter(sampler) + + assert len(sampler) == 4 + filtering_path = os.path.join(datadir, filtering_file) + sampler.filter_chips(filtering_path, 'intersects', 'drop') + assert len(sampler) == 3 + assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) + + def test_filtering_from_gdf(self, datadir: str) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + iterator = iter(sampler) + + # Dropping first chip + assert len(sampler) == 4 + filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4')) + sampler.filter_chips(filtering_gdf, 'intersects', 'drop') + assert len(sampler) == 3 + assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) + + # Keeping only first chip + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + iterator = iter(sampler) + sampler.filter_chips(filtering_gdf, 'intersects', 'keep') + assert len(sampler) == 1 + assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) + + def test_set_worker_split(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + assert len(sampler) == 4 + sampler.set_worker_split(total_workers=4, worker_num=1) + assert len(sampler) == 1 + + def test_save_chips(self, tmpdir_factory) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + sampler.save(str(tmpdir_factory.mktemp('out').join('chips'))) + sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather'))) + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -139,6 +229,15 @@ def test_weighted_sampling(self) -> None: for bbox in sampler: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_return_as_ts(self) -> None: + ds = CustomGeoDatasetSITS() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 10, 0, 10, 15, 20)) + sampler = RandomGeoSampler(ds, 1, 5) + for query in sampler: + assert query.mint == ds.bounds.mint == 0 + assert query.maxt == ds.bounds.maxt == 20 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -156,7 +255,7 @@ class TestGridGeoSampler: def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 100, 200, 300, 400, 500)) - ds.index.insert(1, (0, 100, 200, 300, 400, 500)) + ds.index.insert(1, (0, 100, 200, 300, 500, 600)) return ds @pytest.fixture( @@ -197,13 +296,13 @@ def test_iter(self, sampler: GridGeoSampler) -> None: assert math.isclose(query.maxx - query.minx, sampler.size[1]) assert math.isclose(query.maxy - query.miny, sampler.size[0]) - assert math.isclose( - query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint - ) + assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt def test_len(self, sampler: GridGeoSampler) -> None: rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride) - length = rows * cols * 2 # two items in dataset + length = ( + rows * cols * 2 + ) # two spatially but not temporally overlapping items in dataset assert len(sampler) == length def test_roi(self, dataset: CustomGeoDataset) -> None: @@ -243,6 +342,29 @@ def test_float_multiple(self) -> None: assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) + def test_dataset_has_regex(self) -> None: + ds = CustomGeoDataset() + ds.filename_regex = r'.*(?Ptest)' + ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test') + sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS) + assert 'my_key' in sampler.chips.columns + + def test_dataset_has_regex_no_match(self) -> None: + ds = CustomGeoDataset() + ds.filename_regex = r'(?Ptest)' + ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'no_matching_key') + sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS) + assert 'my_key' not in sampler.chips.columns + + def test_return_as_ts(self) -> None: + ds = CustomGeoDatasetSITS() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 10, 0, 10, 15, 20)) + sampler = GridGeoSampler(ds, 1, 1) + for query in sampler: + assert query.mint == ds.bounds.mint == 0 + assert query.maxt == ds.bounds.maxt == 20 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 64f7ed1ceb0..5d5175c63b4 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -97,6 +97,9 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): #: a different file format than what it was originally downloaded as. filename_glob = '*' + # Whether to return the dataset as a Timeseries, this will add another dimension to the dataset + return_as_ts = False + # NOTE: according to the Python docs: # # * https://docs.python.org/3/library/exceptions.html#NotImplementedError @@ -980,6 +983,7 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError('IntersectionDataset only supports GeoDatasets') + self.return_as_ts = dataset1.return_as_ts or dataset2.return_as_ts self.crs = dataset1.crs self.res = dataset1.res @@ -1140,6 +1144,7 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError('UnionDataset only supports GeoDatasets') + self.return_as_ts = dataset1.return_as_ts and dataset2.return_as_ts self.crs = dataset1.crs self.res = dataset1.res diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 094142cb647..2924edad7f4 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,17 +4,86 @@ """TorchGeo samplers.""" import abc +import re +import warnings from collections.abc import Callable, Iterable, Iterator +from typing import Any +import geopandas as gpd +import numpy as np +import pandas as pd import torch +from geopandas import GeoDataFrame from rtree.index import Index, Property +from shapely.geometry import box from torch.utils.data import Sampler +from tqdm import tqdm from ..datasets import BoundingBox, GeoDataset from .constants import Units from .utils import _to_tuple, get_random_bounding_box, tile_to_chips +def load_file(path: str | GeoDataFrame) -> GeoDataFrame: + """Load a file from the given path. + + Parameters: + path (str or GeoDataFrame): The path to the file or a GeoDataFrame object. + + Returns: + GeoDataFrame: The loaded file as a GeoDataFrame. + + Raises: + None + + """ + if isinstance(path, GeoDataFrame): + return path + if path.endswith('.feather'): + print(f'Reading feather file: {path}') + return gpd.read_feather(path) + else: + print(f'Reading shapefile: {path}') + return gpd.read_file(path) + + +def _get_regex_groups_as_df(dataset: GeoDataset, hits: list) -> pd.DataFrame: + """Extracts the regex metadata from a list of hits. + + Args: + dataset (GeoDataset): The dataset to sample from. + hits (list): A list of hits. + + Returns: + pandas.DataFrame: A DataFrame containing the extracted file metadata. + """ + has_filename_regex = hasattr(dataset, 'filename_regex') + if has_filename_regex: + filename_regex = re.compile(dataset.filename_regex, re.VERBOSE) + file_metadata = [] + for hit in hits: + if has_filename_regex: + match = re.match(filename_regex, str(hit.object)) + if match: + meta = match.groupdict() + else: + meta = {} + else: + meta = {} + meta.update( + { + 'minx': hit.bounds[0], + 'maxx': hit.bounds[1], + 'miny': hit.bounds[2], + 'maxy': hit.bounds[3], + 'mint': hit.bounds[4], + 'maxt': hit.bounds[5], + } + ) + file_metadata.append(meta) + return pd.DataFrame(file_metadata) + + class GeoSampler(Sampler[BoundingBox], abc.ABC): """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`. @@ -44,18 +113,103 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: self.res = dataset.res self.roi = roi + self.dataset = dataset @abc.abstractmethod + def get_chips(self, *args: Any, **kwargs: Any) -> GeoDataFrame: + """Determines the way to get the extend of the chips (samples) of the dataset. + + Should return a GeoDataFrame with the extend of the chips with the columns + geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip. + """ + + def filter_chips( + self, + filter_by: str | GeoDataFrame, + predicate: str = 'intersects', + action: str = 'keep', + ) -> None: + """Filter the default set of chips in the sampler down to a specific subset by specifying files + supported by geopandas such as shapefiles, geodatabases or feather files. + + Args: + filter_by: The file or geodataframe for which the geometries will be used during filtering + predicate: Predicate as used in Geopandas sindex.query_bulk + action: What to do with the chips that satisfy the condition by the predicacte. + Can either be "drop" or "keep". + """ + prefilter_leng = len(self.chips) + filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs) + + if action == 'keep': + self.chips = self.chips.iloc[ + list( + set( + self.chips.sindex.query_bulk( + filtering_gdf.geometry, predicate=predicate + )[1] + ) + ) + ].reset_index(drop=True) + elif action == 'drop': + self.chips = self.chips.drop( + index=list( + set( + self.chips.sindex.query_bulk( + filtering_gdf.geometry, predicate=predicate + )[1] + ) + ) + ).reset_index(drop=True) + + self.chips.fid = self.chips.index + print(f'Filter step reduced chips from {prefilter_leng} to {len(self.chips)}') + assert not self.chips.empty, 'No chips left after filtering!' + + def set_worker_split(self, total_workers: int, worker_num: int) -> None: + """Splits the chips in n equal parts for the number of workers and keeps the set of + chips for the specific worker id, convenient if you want to split the chips across + multiple dataloaders for multi-gpu inference. + + Args: + total_workers: The total number of parts to split the chips + worker_num: The id of the worker (which part to keep), starts from 0 + + """ + self.chips = np.array_split(self.chips, total_workers)[worker_num] + + def save(self, path: str, driver: str) -> None: + """Save the chips as a shapefile or feather file""" + if path.endswith('.feather'): + self.chips.to_feather(path) + else: + self.chips.to_file(path, driver=driver) + def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ + for _, chip in self.chips.iterrows(): + yield BoundingBox( + chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt + ) + + def __len__(self) -> int: + """Return the number of samples over the ROI. + + Returns: + number of patches that will be sampled + """ + return len(self.chips) class RandomGeoSampler(GeoSampler): - """Samples elements from a region of interest randomly. + """Differs from TorchGeo's official RandomGeoSampler in that it can sample SITS data. + + Documentation from TorchGeo: + Samples elements from a region of interest randomly. This is particularly useful during training when you want to maximize the size of the dataset and return as many random :term:`chips ` as possible. Note that @@ -105,7 +259,7 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) - self.length = 0 + num_chips = 0 self.hits = [] areas = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -116,43 +270,53 @@ def __init__( ): if bounds.area > 0: rows, cols = tile_to_chips(bounds, self.size) - self.length += rows * cols + num_chips += rows * cols else: - self.length += 1 + num_chips += 1 self.hits.append(hit) areas.append(bounds.area) if length is not None: - self.length = length + num_chips = length + self.length = num_chips # torch.multinomial requires float probabilities > 0 self.areas = torch.tensor(areas, dtype=torch.float) if torch.sum(self.areas) == 0: self.areas += 1 - def __iter__(self) -> Iterator[BoundingBox]: - """Return the index of a dataset. + self.chips = self.get_chips(num_samples=num_chips) - Returns: - (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset - """ - for _ in range(len(self)): + def get_chips(self, num_samples) -> GeoDataFrame: + chips = [] + for _ in tqdm(range(num_samples)): # Choose a random tile, weighted by area idx = torch.multinomial(self.areas, 1) hit = self.hits[idx] - bounds = BoundingBox(*hit.bounds) + hit_bounds = hit.bounds + if self.dataset.return_as_ts: + hit_bounds[-2] = self.dataset.bounds.mint + hit_bounds[-1] = self.dataset.bounds.maxt + bounds = BoundingBox(*hit_bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) - - yield bounding_box - - def __len__(self) -> int: - """Return the number of samples in a single epoch. - - Returns: - length of the epoch - """ - return self.length + bbox = get_random_bounding_box(bounds, self.size, self.res) + minx, maxx, miny, maxy, mint, maxt = tuple(bbox) + chip = { + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, + } + chips.append(chip) + + print('creating geodataframe... ') + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf['fid'] = chips_gdf.index + + return chips_gdf class GridGeoSampler(GeoSampler): @@ -206,33 +370,38 @@ def __init__( self.size = (self.size[0] * self.res, self.size[1] * self.res) self.stride = (self.stride[0] * self.res, self.stride[1] * self.res) - self.hits = [] - for hit in self.index.intersection(tuple(self.roi), objects=True): - bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx >= self.size[1] - and bounds.maxy - bounds.miny >= self.size[0] - ): - self.hits.append(hit) - - self.length = 0 - for hit in self.hits: - bounds = BoundingBox(*hit.bounds) - rows, cols = tile_to_chips(bounds, self.size, self.stride) - self.length += rows * cols + hits = self.index.intersection(tuple(self.roi), objects=True) + df_path = _get_regex_groups_as_df(self.dataset, hits) - def __iter__(self) -> Iterator[BoundingBox]: - """Return the index of a dataset. + # Filter out tiles smaller than the chip size + self.df_path = df_path[ + (df_path.maxx - df_path.minx >= self.size[1]) + & (df_path.maxy - df_path.miny >= self.size[0]) + ] - Returns: - (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset - """ - # For each tile... - for hit in self.hits: - bounds = BoundingBox(*hit.bounds) + # Filter out hits in the index that share the same extent + if self.dataset.return_as_ts: + self.df_path.drop_duplicates( + subset=['minx', 'maxx', 'miny', 'maxy'], inplace=True + ) + else: + self.df_path.drop_duplicates( + subset=['minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt'], inplace=True + ) + + self.chips = self.get_chips() + + def get_chips(self) -> GeoDataFrame: + print('generating samples... ') + optional_keys = set(self.df_path.keys()) - set( + ['geometry', 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt'] + ) + chips = [] + for _, row in tqdm(self.df_path.iterrows(), total=len(self.df_path)): + bounds = BoundingBox( + row.minx, row.maxx, row.miny, row.maxy, row.mint, row.maxt + ) rows, cols = tile_to_chips(bounds, self.size, self.stride) - mint = bounds.mint - maxt = bounds.maxt # For each row... for i in range(rows): @@ -244,15 +413,37 @@ def __iter__(self) -> Iterator[BoundingBox]: minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] - yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) - - def __len__(self) -> int: - """Return the number of samples over the ROI. + if self.dataset.return_as_ts: + mint = self.dataset.bounds.mint + maxt = self.dataset.bounds.maxt + else: + mint = bounds.mint + maxt = bounds.maxt + + chip = { + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, + } + for key in optional_keys: + if key in row.keys(): + chip[key] = row[key] + + chips.append(chip) + + if chips: + print('creating geodataframe... ') + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf['fid'] = chips_gdf.index - Returns: - number of patches that will be sampled - """ - return self.length + else: + warnings.warn('Sampler has no chips, check your inputs') + chips_gdf = GeoDataFrame() + return chips_gdf class PreChippedGeoSampler(GeoSampler): @@ -289,23 +480,29 @@ def __init__( for hit in self.index.intersection(tuple(self.roi), objects=True): self.hits.append(hit) - def __iter__(self) -> Iterator[BoundingBox]: - """Return the index of a dataset. + self.chips = self.get_chips() - Returns: - (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset - """ + def get_chips(self) -> GeoDataFrame: generator: Callable[[int], Iterable[int]] = range if self.shuffle: generator = torch.randperm - for idx in generator(len(self)): - yield BoundingBox(*self.hits[idx].bounds) - - def __len__(self) -> int: - """Return the number of samples over the ROI. - - Returns: - number of patches that will be sampled - """ - return len(self.hits) + chips = [] + for idx in generator(len(self.hits)): + minx, maxx, miny, maxy, mint, maxt = self.hits[idx].bounds + chip = { + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, + } + chips.append(chip) + + print('creating geodataframe... ') + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf['fid'] = chips_gdf.index + + return chips_gdf