diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py index 1c2af9b0..b75da89b 100644 --- a/terratorch/datamodules/__init__.py +++ b/terratorch/datamodules/__init__.py @@ -39,6 +39,9 @@ from terratorch.datamodules.biomassters import BioMasstersNonGeoDataModule from terratorch.datamodules.forestnet import ForestNetNonGeoDataModule +# miscellaneous datamodules +from terratorch.datamodules.openearthmap import OpenEarthMapNonGeoDataModule + # Generic classification datamodule from terratorch.datamodules.sen4map import Sen4MapLucasDataModule diff --git a/terratorch/datasets/__init__.py b/terratorch/datasets/__init__.py index 9f7f3bb1..41cd2fa7 100644 --- a/terratorch/datasets/__init__.py +++ b/terratorch/datasets/__init__.py @@ -50,7 +50,7 @@ "GenericNonGeoSegmentationDataset", "GenericNonGeoPixelwiseRegressionDataset", "GenericNonGeoClassificationDataset", - "GenericNonGeoRegressionDataset", + #"GenericNonGeoRegressionDataset", "BurnIntensityNonGeo", "CarbonFluxNonGeo", "Landslide4SenseNonGeo", diff --git a/terratorch/datasets/fire_scars.py b/terratorch/datasets/fire_scars.py index 7ddad452..f5b65516 100644 --- a/terratorch/datasets/fire_scars.py +++ b/terratorch/datasets/fire_scars.py @@ -198,7 +198,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure ax[3].imshow(image) ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm) - if prediction: + if "prediction" in sample: ax[4].title.set_text("Predicted Mask") ax[4].imshow(prediction, cmap="jet", norm=norm) diff --git a/terratorch/datasets/landslide4sense.py b/terratorch/datasets/landslide4sense.py index 54b71e06..5c949ded 100644 --- a/terratorch/datasets/landslide4sense.py +++ b/terratorch/datasets/landslide4sense.py @@ -137,7 +137,7 @@ def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> ax[2].axis("off") if "prediction" in sample: - prediction = sample["prediction"].numpy() + prediction = sample["prediction"] ax[3].imshow(prediction, cmap=cmap, norm=norm) ax[3].set_title("Predicted Mask") ax[3].axis("off") diff --git a/terratorch/datasets/multi_temporal_crop_classification.py b/terratorch/datasets/multi_temporal_crop_classification.py index 32e5421f..709800d4 100644 --- a/terratorch/datasets/multi_temporal_crop_classification.py +++ b/terratorch/datasets/multi_temporal_crop_classification.py @@ -235,37 +235,35 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure raise ValueError(msg) images = sample["image"] - if not self.expand_temporal_dimension: - images = rearrange(images, "(channels time) h w -> channels time h w", channels=len(self.bands)) + images = images[rgb_indices, ...] # Shape: (T, 3, H, W) - # RGB -> channels-last - images = images[rgb_indices, ...].permute(1, 2, 3, 0).numpy() - mask = sample["mask"].numpy() - - images = [clip_image(img) for img in images] + processed_images = [] + for t in range(self.time_steps): + img = images[t] + img = img.permute(1, 2, 0) + img = img.numpy() + img = clip_image(img) + processed_images.append(img) + mask = sample["mask"].numpy() if "prediction" in sample: - prediction = sample["prediction"] num_images += 1 - else: - prediction = None - fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed") - ax[0].axis("off") norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1) + for i, img in enumerate(processed_images): + ax[i + 1].axis("off") + ax[i + 1].title.set_text(f"T{i}") + ax[i + 1].imshow(img) - for i, img in enumerate(images): - ax[i+1].axis("off") - ax[i+1].title.set_text(f"T{i}") - ax[i+1].imshow(img) - - ax[self.time_steps+1].axis("off") - ax[self.time_steps+1].title.set_text("Ground Truth Mask") - ax[self.time_steps+1].imshow(mask, cmap="jet", norm=norm) + ax[self.time_steps + 1].axis("off") + ax[self.time_steps + 1].title.set_text("Ground Truth Mask") + ax[self.time_steps + 1].imshow(mask, cmap="jet", norm=norm) - if prediction: + if "prediction" in sample: + prediction = sample["prediction"] + ax[self.time_steps + 1].axis("off") ax[self.time_steps+2].title.set_text("Predicted Mask") ax[self.time_steps+2].imshow(prediction, cmap="jet", norm=norm) @@ -274,6 +272,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data] labels = [n for k, c, n in legend_data] ax[0].legend(handles, labels, loc="center") + if suptitle is not None: plt.suptitle(suptitle) diff --git a/terratorch/datasets/sen1floods11.py b/terratorch/datasets/sen1floods11.py index e6fe9362..b36965c7 100644 --- a/terratorch/datasets/sen1floods11.py +++ b/terratorch/datasets/sen1floods11.py @@ -228,7 +228,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure ax[3].imshow(image) ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm) - if prediction: + if "prediction" in sample: ax[4].title.set_text("Predicted Mask") ax[4].imshow(prediction, cmap="jet", norm=norm)