Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ruff format
Browse files Browse the repository at this point in the history
rijuld committed Jan 19, 2025
1 parent 8c918a8 commit 39668ca
Showing 2 changed files with 10 additions and 6 deletions.
6 changes: 3 additions & 3 deletions tests/datasets/test_substation.py
Original file line number Diff line number Diff line change
@@ -87,9 +87,9 @@ def test_getitem_semantic(self, config: dict[str, Any]) -> None:

x = dataset[0]
assert isinstance(x, dict), f'Expected dict, got {type(x)}'
assert isinstance(
x['image'], torch.Tensor
), 'Expected image to be a torch.Tensor'
assert isinstance(x['image'], torch.Tensor), (
'Expected image to be a torch.Tensor'
)
assert isinstance(x['mask'], torch.Tensor), 'Expected mask to be a torch.Tensor'

def test_len(self, dataset: Substation) -> None:
10 changes: 7 additions & 3 deletions torchgeo/datamodules/substation.py
Original file line number Diff line number Diff line change
@@ -75,15 +75,19 @@ def __init__(
self.image_resize = image_resize
self.mask_resize = mask_resize
self.num_of_timepoints = num_of_timepoints
self.geo_transforms = geo_transforms if geo_transforms is not None else self._identity
self.color_transforms = color_transforms if color_transforms is not None else self._identity
self.geo_transforms = (
geo_transforms if geo_transforms is not None else self._identity
)
self.color_transforms = (
color_transforms if color_transforms is not None else self._identity
)
self.image_resize = image_resize if image_resize is not None else self._identity
self.mask_resize = mask_resize if mask_resize is not None else self._identity

def _identity(self, x: torch.Tensor) -> torch.Tensor:
"""Identity function for default transformations."""
return x

def setup(self, stage: str) -> None:
"""Set up datasets.

0 comments on commit 39668ca

Please sign in to comment.