Skip to content

Commit

Permalink
Reformatting the source code using black
Browse files Browse the repository at this point in the history
Signed-off-by: Joao Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Jul 7, 2024
1 parent 308d540 commit 1e00471
Show file tree
Hide file tree
Showing 49 changed files with 738 additions and 703 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/convert_sen1floods11_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

data = np.genfromtxt(input_file, delimiter=',', dtype=str)

col1 = data[:,0].tolist()
col1 = data[:, 0].tolist()

col1_ = ["_".join(i.split("_")[:2]) for i in col1]

Expand Down
25 changes: 13 additions & 12 deletions examples/scripts/instantiate_satmae_backbone.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import torch
import torch
import numpy as np

from models_mae import MaskedAutoencoderViT

kwargs = {"img_size":224,
"patch_size":16,
"in_chans":3,
"embed_dim":1024,
"depth":24,
"num_heads":16,
"decoder_embed_dim":512,
"decoder_depth":8,
"decoder_num_heads":16,
"mlp_ratio":4.}
kwargs = {
"img_size": 224,
"patch_size": 16,
"in_chans": 3,
"embed_dim": 1024,
"depth": 24,
"num_heads": 16,
"decoder_embed_dim": 512,
"decoder_depth": 8,
"decoder_num_heads": 16,
"mlp_ratio": 4.0,
}

vit_mae = MaskedAutoencoderViT(**kwargs)

Expand All @@ -29,4 +31,3 @@

print(f"Output shape: {reconstructed.shape}")
print("Done.")

13 changes: 9 additions & 4 deletions terratorch/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def write_tiff(img_wrt, filename, metadata):
return filename


def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"):
def save_prediction(prediction, input_file_name, out_dir, dtype: str = "int16"):
mask, metadata = open_tiff(input_file_name)
mask = np.where(mask == metadata["nodata"], 1, 0)
mask = np.max(mask, axis=0)
Expand Down Expand Up @@ -310,10 +310,11 @@ def instantiate_classes(self) -> None:
config = self.config
if hasattr(config, "predict_output_dir"):
self.trainer.predict_output_dir = config.predict_output_dir

if hasattr(config, "out_dtype"):
self.trainer.out_dtype = config.out_dtype


def build_lightning_cli(
args: ArgsType = None,
run=True, # noqa: FBT002
Expand Down Expand Up @@ -413,8 +414,12 @@ def from_config(
]

if predict_dataset_bands is not None:
arguments.extend([ "--data.init_args.predict_dataset_bands",
"[" + ",".join(predict_dataset_bands) + "]",])
arguments.extend(
[
"--data.init_args.predict_dataset_bands",
"[" + ",".join(predict_dataset_bands) + "]",
]
)

cli = build_lightning_cli(arguments, run=False)
trainer = cli.trainer
Expand Down
2 changes: 1 addition & 1 deletion terratorch/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@
"MChesapeakeLandcoverNonGeoDataModule",
"MPv4gerSegNonGeoDataModule",
"MSACropTypeNonGeoDataModule",
"MNeonTreeNonGeoDataModule"
"MNeonTreeNonGeoDataModule",
)
27 changes: 16 additions & 11 deletions terratorch/datamodules/m_SA_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"WATER_VAPOR": 69.904566,
"SWIR_1": 83.626811,
"SWIR_2": 65.767679,
"CLOUD_PROBABILITY": 0.0
"CLOUD_PROBABILITY": 0.0,
}

STDS = {
Expand All @@ -38,36 +38,41 @@
"WATER_VAPOR": 21.877766438821954,
"SWIR_1": 28.14418826277069,
"SWIR_2": 27.2346215312965,
"CLOUD_PROBABILITY": 0.0
"CLOUD_PROBABILITY": 0.0,
}


class MSACropTypeNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
**kwargs: Any,
) -> None:

super().__init__(MSACropTypeNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", MSACropTypeNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"]) if aug is None else aug

self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"])
if aug is None
else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
Expand All @@ -76,5 +81,5 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
65 changes: 34 additions & 31 deletions terratorch/datamodules/m_bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,62 +9,65 @@
from terratorch.datamodules.utils import wrap_in_compose_is_list

MEANS = {
"COASTAL_AEROSOL": 378.4027,
"BLUE": 482.2730,
"GREEN": 706.5345,
"RED": 720.9285,
"RED_EDGE_1": 1100.6688,
"COASTAL_AEROSOL": 378.4027,
"BLUE": 482.2730,
"GREEN": 706.5345,
"RED": 720.9285,
"RED_EDGE_1": 1100.6688,
"RED_EDGE_2": 1909.2914,
"RED_EDGE_3": 2191.6985,
"NIR_BROAD": 2336.8706,
"NIR_NARROW": 2394.7449,
"WATER_VAPOR": 2368.3127,
"SWIR_1": 1875.2487,
"SWIR_2": 1229.3818
"RED_EDGE_3": 2191.6985,
"NIR_BROAD": 2336.8706,
"NIR_NARROW": 2394.7449,
"WATER_VAPOR": 2368.3127,
"SWIR_1": 1875.2487,
"SWIR_2": 1229.3818,
}

STDS = {
"COASTAL_AEROSOL": 157.5666,
"BLUE": 255.0429,
"GREEN": 303.1750,
"RED": 391.2943,
"RED_EDGE_1": 380.7916,
"RED_EDGE_2": 551.6558,
STDS = {
"COASTAL_AEROSOL": 157.5666,
"BLUE": 255.0429,
"GREEN": 303.1750,
"RED": 391.2943,
"RED_EDGE_1": 380.7916,
"RED_EDGE_2": 551.6558,
"RED_EDGE_3": 638.8196,
"NIR_BROAD": 744.2009,
"NIR_NARROW": 675.4041,
"WATER_VAPOR": 561.0154,
"SWIR_1": 563.4095,
"SWIR_2": 479.1786
"NIR_NARROW": 675.4041,
"WATER_VAPOR": 561.0154,
"SWIR_1": 563.4095,
"SWIR_2": 479.1786,
}


class MBigEarthNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
**kwargs: Any,
) -> None:

super().__init__(MBigEarthNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", MBigEarthNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug

self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
Expand All @@ -73,5 +76,5 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
26 changes: 14 additions & 12 deletions terratorch/datamodules/m_brick_kiln.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"WATER_VAPOR": 1129.8171906000355,
"CIRRUS": 83.27188605598549,
"SWIR_1": 90.54924599052214,
"SWIR_2": 68.98768652434848
"SWIR_2": 68.98768652434848,
}

STDS = {
Expand All @@ -38,37 +38,39 @@
"WATER_VAPOR": 704.0219637458916,
"CIRRUS": 36.355745901131705,
"SWIR_1": 28.004671947623894,
"SWIR_2": 24.268892726362033
"SWIR_2": 24.268892726362033,
}

class MBrickKilnNonGeoDataModule(NonGeoDataModule):

class MBrickKilnNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
**kwargs: Any,
) -> None:

super().__init__(MBrickKilnNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", MBrickKilnNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug

self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
Expand All @@ -77,5 +79,5 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
30 changes: 17 additions & 13 deletions terratorch/datamodules/m_cashew_plantation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"WATER_VAPOR": 2852.87451171875,
"SWIR_1": 2463.933349609375,
"SWIR_2": 1600.9207763671875,
"CLOUD_PROBABILITY": 0.010281000286340714
"CLOUD_PROBABILITY": 0.010281000286340714,
}

STDS = {
Expand All @@ -38,37 +38,41 @@
"WATER_VAPOR": 413.8980407714844,
"SWIR_1": 494.97430419921875,
"SWIR_2": 514.4229736328125,
"CLOUD_PROBABILITY": 0.3447800576686859
"CLOUD_PROBABILITY": 0.3447800576686859,
}

class MBeninSmallHolderCashewsNonGeoDataModule(NonGeoDataModule):

class MBeninSmallHolderCashewsNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
**kwargs: Any,
) -> None:

super().__init__(MBeninSmallHolderCashewsNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", MBeninSmallHolderCashewsNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"]) if aug is None else aug

self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"])
if aug is None
else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
Expand All @@ -77,5 +81,5 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
Loading

0 comments on commit 1e00471

Please sign in to comment.