Skip to content

Commit

Permalink
Strings are allowed to define bands
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Jul 22, 2024
1 parent e48965f commit 5dba481
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
12 changes: 6 additions & 6 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int | list[int]] | None = None,
predict_dataset_bands: list[HLSBands | int | list[int]] | None = None,
output_bands: list[HLSBands | int | list[int]] | None = None,
dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
output_bands: list[HLSBands | int | list[int] | str] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
Expand Down Expand Up @@ -330,9 +330,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int | list[int]] | None = None,
predict_dataset_bands: list[HLSBands | int | list[int]] | None = None,
output_bands: list[HLSBands | int | list[int]] | None = None,
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
Expand Down
17 changes: 10 additions & 7 deletions terratorch/datasets/generic_pixel_wise_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
expected 0. Defaults to False.
"""
super().__init__()

self.split_file = split

label_data_root = label_data_root if label_data_root is not None else data_root
Expand Down Expand Up @@ -136,7 +137,7 @@ def __init__(
if bands_type == str:
raise UserWarning("When the bands are defined as str, guarantee your input files"+
"are organized by band and all have its specific name.")

if self.output_bands and not self.dataset_bands:
msg = "If output bands provided, dataset_bands must also be provided"
return Exception(msg) # noqa: PLE0101
Expand All @@ -146,7 +147,9 @@ def __init__(
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
msg = "Output bands must be a subset of dataset bands"
raise Exception(msg)

self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

else:
self.filter_indices = None

Expand Down Expand Up @@ -176,7 +179,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
if self.transform:
output = self.transform(**output)
return output

def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
data = rioxarray.open_rasterio(path, masked=True)
if nan_replace is not None:
Expand All @@ -200,7 +203,7 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
band_type[b] = str
else:
pass
if band_type.cound(band_type[0]) == len(band_type)
if band_type.cound(band_type[0]) == len(band_type):
return band_type[0]
else:
raise Exception("The bands must be or all str or all int.")
Expand Down Expand Up @@ -232,8 +235,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[str] | None = None,
dataset_bands: list[HLSBands | int | list[int]] | None = None,
output_bands: list[HLSBands | int | list[int]] | None = None,
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
class_names: list[str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
Expand Down Expand Up @@ -399,8 +402,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int | list[int]] | None = None,
output_bands: list[HLSBands | int | list[int]] | None = None,
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float | None = None,
Expand Down

0 comments on commit 5dba481

Please sign in to comment.