Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reformatting the source code using black #43

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/scripts/convert_sen1floods11_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

data = np.genfromtxt(input_file, delimiter=',', dtype=str)

col1 = data[:,0].tolist()
col1 = data[:, 0].tolist()

col1_ = ["_".join(i.split("_")[:2]) for i in col1]

Expand Down
25 changes: 13 additions & 12 deletions examples/scripts/instantiate_satmae_backbone.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import torch
import torch
import numpy as np

from models_mae import MaskedAutoencoderViT

kwargs = {"img_size":224,
"patch_size":16,
"in_chans":3,
"embed_dim":1024,
"depth":24,
"num_heads":16,
"decoder_embed_dim":512,
"decoder_depth":8,
"decoder_num_heads":16,
"mlp_ratio":4.}
kwargs = {
"img_size": 224,
"patch_size": 16,
"in_chans": 3,
"embed_dim": 1024,
"depth": 24,
"num_heads": 16,
"decoder_embed_dim": 512,
"decoder_depth": 8,
"decoder_num_heads": 16,
"mlp_ratio": 4.0,
}

vit_mae = MaskedAutoencoderViT(**kwargs)

Expand All @@ -29,4 +31,3 @@

print(f"Output shape: {reconstructed.shape}")
print("Done.")

13 changes: 9 additions & 4 deletions terratorch/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def write_tiff(img_wrt, filename, metadata):
return filename


def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"):
def save_prediction(prediction, input_file_name, out_dir, dtype: str = "int16"):
mask, metadata = open_tiff(input_file_name)
mask = np.where(mask == metadata["nodata"], 1, 0)
mask = np.max(mask, axis=0)
Expand Down Expand Up @@ -310,10 +310,11 @@ def instantiate_classes(self) -> None:
config = self.config
if hasattr(config, "predict_output_dir"):
self.trainer.predict_output_dir = config.predict_output_dir

if hasattr(config, "out_dtype"):
self.trainer.out_dtype = config.out_dtype


def build_lightning_cli(
args: ArgsType = None,
run=True, # noqa: FBT002
Expand Down Expand Up @@ -413,8 +414,12 @@ def from_config(
]

if predict_dataset_bands is not None:
arguments.extend([ "--data.init_args.predict_dataset_bands",
"[" + ",".join(predict_dataset_bands) + "]",])
arguments.extend(
[
"--data.init_args.predict_dataset_bands",
"[" + ",".join(predict_dataset_bands) + "]",
]
)

cli = build_lightning_cli(arguments, run=False)
trainer = cli.trainer
Expand Down
2 changes: 1 addition & 1 deletion terratorch/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@
"MChesapeakeLandcoverNonGeoDataModule",
"MPv4gerSegNonGeoDataModule",
"MSACropTypeNonGeoDataModule",
"MNeonTreeNonGeoDataModule"
"MNeonTreeNonGeoDataModule",
)
27 changes: 16 additions & 11 deletions terratorch/datamodules/m_SA_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"WATER_VAPOR": 69.904566,
"SWIR_1": 83.626811,
"SWIR_2": 65.767679,
"CLOUD_PROBABILITY": 0.0
"CLOUD_PROBABILITY": 0.0,
}

STDS = {
Expand All @@ -38,36 +38,41 @@
"WATER_VAPOR": 21.877766438821954,
"SWIR_1": 28.14418826277069,
"SWIR_2": 27.2346215312965,
"CLOUD_PROBABILITY": 0.0
"CLOUD_PROBABILITY": 0.0,
}


class MSACropTypeNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
**kwargs: Any,
) -> None:

super().__init__(MSACropTypeNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", MSACropTypeNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"]) if aug is None else aug

self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"])
if aug is None
else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
Expand All @@ -76,5 +81,5 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
65 changes: 34 additions & 31 deletions terratorch/datamodules/m_bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,62 +9,65 @@
from terratorch.datamodules.utils import wrap_in_compose_is_list

MEANS = {
"COASTAL_AEROSOL": 378.4027,
"BLUE": 482.2730,
"GREEN": 706.5345,
"RED": 720.9285,
"RED_EDGE_1": 1100.6688,
"COASTAL_AEROSOL": 378.4027,
"BLUE": 482.2730,
"GREEN": 706.5345,
"RED": 720.9285,
"RED_EDGE_1": 1100.6688,
"RED_EDGE_2": 1909.2914,
"RED_EDGE_3": 2191.6985,
"NIR_BROAD": 2336.8706,
"NIR_NARROW": 2394.7449,
"WATER_VAPOR": 2368.3127,
"SWIR_1": 1875.2487,
"SWIR_2": 1229.3818
"RED_EDGE_3": 2191.6985,
"NIR_BROAD": 2336.8706,
"NIR_NARROW": 2394.7449,
"WATER_VAPOR": 2368.3127,
"SWIR_1": 1875.2487,
"SWIR_2": 1229.3818,
}

STDS = {
"COASTAL_AEROSOL": 157.5666,
"BLUE": 255.0429,
"GREEN": 303.1750,
"RED": 391.2943,
"RED_EDGE_1": 380.7916,
"RED_EDGE_2": 551.6558,
STDS = {
"COASTAL_AEROSOL": 157.5666,
"BLUE": 255.0429,
"GREEN": 303.1750,
"RED": 391.2943,
"RED_EDGE_1": 380.7916,
"RED_EDGE_2": 551.6558,
"RED_EDGE_3": 638.8196,
"NIR_BROAD": 744.2009,
"NIR_NARROW": 675.4041,
"WATER_VAPOR": 561.0154,
"SWIR_1": 563.4095,
"SWIR_2": 479.1786
"NIR_NARROW": 675.4041,
"WATER_VAPOR": 561.0154,
"SWIR_1": 563.4095,
"SWIR_2": 479.1786,
}


class MBigEarthNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
**kwargs: Any,
) -> None:

super().__init__(MBigEarthNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", MBigEarthNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug

self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
Expand All @@ -73,5 +76,5 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
26 changes: 14 additions & 12 deletions terratorch/datamodules/m_brick_kiln.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"WATER_VAPOR": 1129.8171906000355,
"CIRRUS": 83.27188605598549,
"SWIR_1": 90.54924599052214,
"SWIR_2": 68.98768652434848
"SWIR_2": 68.98768652434848,
}

STDS = {
Expand All @@ -38,37 +38,39 @@
"WATER_VAPOR": 704.0219637458916,
"CIRRUS": 36.355745901131705,
"SWIR_1": 28.004671947623894,
"SWIR_2": 24.268892726362033
"SWIR_2": 24.268892726362033,
}

class MBrickKilnNonGeoDataModule(NonGeoDataModule):

class MBrickKilnNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
**kwargs: Any,
) -> None:

super().__init__(MBrickKilnNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", MBrickKilnNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug

self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
Expand All @@ -77,5 +79,5 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
30 changes: 17 additions & 13 deletions terratorch/datamodules/m_cashew_plantation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"WATER_VAPOR": 2852.87451171875,
"SWIR_1": 2463.933349609375,
"SWIR_2": 1600.9207763671875,
"CLOUD_PROBABILITY": 0.010281000286340714
"CLOUD_PROBABILITY": 0.010281000286340714,
}

STDS = {
Expand All @@ -38,37 +38,41 @@
"WATER_VAPOR": 413.8980407714844,
"SWIR_1": 494.97430419921875,
"SWIR_2": 514.4229736328125,
"CLOUD_PROBABILITY": 0.3447800576686859
"CLOUD_PROBABILITY": 0.3447800576686859,
}

class MBeninSmallHolderCashewsNonGeoDataModule(NonGeoDataModule):

class MBeninSmallHolderCashewsNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
**kwargs: Any,
) -> None:

super().__init__(MBeninSmallHolderCashewsNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", MBeninSmallHolderCashewsNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"]) if aug is None else aug

self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image", "mask"])
if aug is None
else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
Expand All @@ -77,5 +81,5 @@ def setup(self, stage: str) -> None:
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
Loading
Loading