Skip to content

Commit

Permalink
added more test coverage for extract and verify
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuld committed Jan 19, 2025
1 parent 7c8c71a commit 6c2b1cb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
31 changes: 30 additions & 1 deletion tests/datasets/test_substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
import torch

from torchgeo.datasets import Substation
from torchgeo.datasets import Substation, DatasetNotFoundError


class TestSubstation:

Check failure on line 18 in tests/datasets/test_substation.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

tests/datasets/test_substation.py:4:1: I001 Import block is un-sorted or un-formatted
Expand Down Expand Up @@ -156,6 +156,35 @@ def test_download(
mock_extract_archive.assert_called()
assert mock_extract_archive.call_count == 2

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Substation(
bands=[1, 2, 3],
use_timepoints=True,
mask_2d=True,
timepoint_aggregation='median',
num_of_timepoints=4,
root=tmp_path,
)

def test_extract(self, tmp_path: Path) -> None:
filename = Substation.filename_images
maskname = Substation.filename_masks
shutil.copyfile(
os.path.join('tests', 'data', 'substation', filename), tmp_path / filename
)
shutil.copyfile(
os.path.join('tests', 'data', 'substation', maskname), tmp_path / maskname
)
Substation(
bands=[1, 2, 3],
use_timepoints=True,
mask_2d=True,
timepoint_aggregation='median',
num_of_timepoints=4,
root=tmp_path,
)


if __name__ == '__main__':
pytest.main([__file__])
4 changes: 2 additions & 2 deletions torchgeo/datasets/substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from matplotlib.figure import Figure
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive

Expand Down Expand Up @@ -213,14 +214,13 @@ def _verify(self) -> None:
# Check if the tar.gz files for images and masks have already been downloaded
image_exists = os.path.exists(os.path.join(self.root, self.filename_images))
mask_exists = os.path.exists(os.path.join(self.root, self.filename_masks))

if image_exists and mask_exists:
self._extract()
return

# If dataset files are missing and download is not allowed, raise an error
if not getattr(self, 'download', True):
raise FileNotFoundError(
raise DatasetNotFoundError(
f'Dataset files not found in {self.root}. Enable downloading or provide the files.'
)

Expand Down

0 comments on commit 6c2b1cb

Please sign in to comment.