diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index 8e5fd13d292..d1924984ea7 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -11,6 +11,7 @@ from matplotlib.figure import Figure from rasterio.crs import CRS from torch import Tensor +from geopandas import GeoDataFrame from torchgeo.datamodules import ( GeoDataModule, @@ -182,7 +183,7 @@ def test_zero_length_sampler(self) -> None: dm = CustomGeoDataModule() dm.dataset = CustomGeoDataset() dm.sampler = RandomGeoSampler(dm.dataset, 1, 1) - dm.sampler.length = 0 + dm.sampler.chips = GeoDataFrame() msg = r'CustomGeoDataModule\.sampler has length 0.' with pytest.raises(MisconfigurationException, match=msg): dm.train_dataloader() diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 00f5889a22f..6fdbf4712fc 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -190,6 +190,7 @@ def test_empty(self, dataset: CustomGeoDataset) -> None: assert len(sampler) == 0 def test_refresh_samples(self, dataset: CustomGeoDataset) -> None: + dataset.index.insert(0, (0, 100, 200, 300, 400, 500)) sampler = RandomGeoSampler(dataset, 5, length=1) samples = list(sampler) assert len(sampler) == 1