Skip to content

Commit

Permalink
Merge pull request #65 from IBM/pin/albumentations
Browse files Browse the repository at this point in the history
pin albumentations
  • Loading branch information
CarlosGomes98 authored Jul 31, 2024
2 parents 9df5f5f + 9a263a8 commit 90d3e49
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ dependencies = [
"geobench>=1.0.0",
"mlflow>=2.12.1",
# broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977
"lightning>=2, <=2.2.5"
"lightning>=2, <=2.2.5",
# see issue #64
"albumentations<=1.4.10"
]

[project.optional-dependencies]
Expand Down
5 changes: 3 additions & 2 deletions terratorch/datasets/generic_pixel_wise_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,15 @@ def __getitem__(self, index: int) -> dict[str, Any]:
"image": image.astype(np.float32) * self.constant_scale,
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
0
],
"filename": self.image_files[index],
]
}

if self.reduce_zero_label:
output["mask"] -= 1
if self.transform:
output = self.transform(**output)
output["filename"] = self.image_files[index]

return output

def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
Expand Down
7 changes: 3 additions & 4 deletions terratorch/datasets/generic_scalar_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:

output = {
"image": image.astype(np.float32) * self.constant_scale,
"label": label,
"filename": self.samples[index][
0
], # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
"label": label, # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
}
if self.transforms:
output = self.transforms(**output)
output["filename"] = self.image_files[index]

return output

def _load_file(self, path) -> xr.DataArray:
Expand Down

0 comments on commit 90d3e49

Please sign in to comment.