diff --git a/tests/datasets/test_substation.py b/tests/datasets/test_substation.py index 30f2e7b42d2..c31eadbdc7f 100644 --- a/tests/datasets/test_substation.py +++ b/tests/datasets/test_substation.py @@ -12,7 +12,7 @@ import pytest import torch -from torchgeo.datasets import Substation +from torchgeo.datasets import Substation, DatasetNotFoundError class TestSubstation: @@ -156,6 +156,35 @@ def test_download( mock_extract_archive.assert_called() assert mock_extract_archive.call_count == 2 + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + Substation( + bands=[1, 2, 3], + use_timepoints=True, + mask_2d=True, + timepoint_aggregation='median', + num_of_timepoints=4, + root=tmp_path, + ) + + def test_extract(self, tmp_path: Path) -> None: + filename = Substation.filename_images + maskname = Substation.filename_masks + shutil.copyfile( + os.path.join('tests', 'data', 'substation', filename), tmp_path / filename + ) + shutil.copyfile( + os.path.join('tests', 'data', 'substation', maskname), tmp_path / maskname + ) + Substation( + bands=[1, 2, 3], + use_timepoints=True, + mask_2d=True, + timepoint_aggregation='median', + num_of_timepoints=4, + root=tmp_path, + ) + if __name__ == '__main__': pytest.main([__file__]) diff --git a/torchgeo/datasets/substation.py b/torchgeo/datasets/substation.py index b3d1e21b7cc..89825d99af5 100644 --- a/torchgeo/datasets/substation.py +++ b/torchgeo/datasets/substation.py @@ -12,6 +12,7 @@ from matplotlib.figure import Figure from torch import Tensor +from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import download_url, extract_archive @@ -213,14 +214,13 @@ def _verify(self) -> None: # Check if the tar.gz files for images and masks have already been downloaded image_exists = os.path.exists(os.path.join(self.root, self.filename_images)) mask_exists = os.path.exists(os.path.join(self.root, self.filename_masks)) - if image_exists and mask_exists: self._extract() return # If dataset files are missing and download is not allowed, raise an error if not getattr(self, 'download', True): - raise FileNotFoundError( + raise DatasetNotFoundError( f'Dataset files not found in {self.root}. Enable downloading or provide the files.' )