From 6c2b1cbc5f7a0e30edf1764d4443e43f22262d43 Mon Sep 17 00:00:00 2001
From: rijuld <errijuldahiya@gmail.com>
Date: Sun, 19 Jan 2025 05:02:05 -0500
Subject: [PATCH] added more test coverage for extract and verify

---
 tests/datasets/test_substation.py | 31 ++++++++++++++++++++++++++++++-
 torchgeo/datasets/substation.py   |  4 ++--
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/tests/datasets/test_substation.py b/tests/datasets/test_substation.py
index 30f2e7b42d2..c31eadbdc7f 100644
--- a/tests/datasets/test_substation.py
+++ b/tests/datasets/test_substation.py
@@ -12,7 +12,7 @@
 import pytest
 import torch
 
-from torchgeo.datasets import Substation
+from torchgeo.datasets import Substation, DatasetNotFoundError
 
 
 class TestSubstation:
@@ -156,6 +156,35 @@ def test_download(
         mock_extract_archive.assert_called()
         assert mock_extract_archive.call_count == 2
 
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
+            Substation(
+                bands=[1, 2, 3],
+                use_timepoints=True,
+                mask_2d=True,
+                timepoint_aggregation='median',
+                num_of_timepoints=4,
+                root=tmp_path,
+            )
+
+    def test_extract(self, tmp_path: Path) -> None:
+        filename = Substation.filename_images
+        maskname = Substation.filename_masks
+        shutil.copyfile(
+            os.path.join('tests', 'data', 'substation', filename), tmp_path / filename
+        )
+        shutil.copyfile(
+            os.path.join('tests', 'data', 'substation', maskname), tmp_path / maskname
+        )
+        Substation(
+            bands=[1, 2, 3],
+            use_timepoints=True,
+            mask_2d=True,
+            timepoint_aggregation='median',
+            num_of_timepoints=4,
+            root=tmp_path,
+        )
+
 
 if __name__ == '__main__':
     pytest.main([__file__])
diff --git a/torchgeo/datasets/substation.py b/torchgeo/datasets/substation.py
index b3d1e21b7cc..89825d99af5 100644
--- a/torchgeo/datasets/substation.py
+++ b/torchgeo/datasets/substation.py
@@ -12,6 +12,7 @@
 from matplotlib.figure import Figure
 from torch import Tensor
 
+from .errors import DatasetNotFoundError
 from .geo import NonGeoDataset
 from .utils import download_url, extract_archive
 
@@ -213,14 +214,13 @@ def _verify(self) -> None:
         # Check if the tar.gz files for images and masks have already been downloaded
         image_exists = os.path.exists(os.path.join(self.root, self.filename_images))
         mask_exists = os.path.exists(os.path.join(self.root, self.filename_masks))
-
         if image_exists and mask_exists:
             self._extract()
             return
 
         # If dataset files are missing and download is not allowed, raise an error
         if not getattr(self, 'download', True):
-            raise FileNotFoundError(
+            raise DatasetNotFoundError(
                 f'Dataset files not found in {self.root}. Enable downloading or provide the files.'
             )