Skip to content

Commit

Permalink
mudancas temporarias
Browse files Browse the repository at this point in the history
  • Loading branch information
PedroConrado committed Dec 3, 2024
1 parent 1321f59 commit 564866f
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 25 deletions.
3 changes: 3 additions & 0 deletions terratorch/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from terratorch.datamodules.biomassters import BioMasstersNonGeoDataModule
from terratorch.datamodules.forestnet import ForestNetNonGeoDataModule

# miscellaneous datamodules
from terratorch.datamodules.openearthmap import OpenEarthMapNonGeoDataModule

# Generic classification datamodule
from terratorch.datamodules.sen4map import Sen4MapLucasDataModule

Expand Down
2 changes: 1 addition & 1 deletion terratorch/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"GenericNonGeoSegmentationDataset",
"GenericNonGeoPixelwiseRegressionDataset",
"GenericNonGeoClassificationDataset",
"GenericNonGeoRegressionDataset",
#"GenericNonGeoRegressionDataset",
"BurnIntensityNonGeo",
"CarbonFluxNonGeo",
"Landslide4SenseNonGeo",
Expand Down
2 changes: 1 addition & 1 deletion terratorch/datasets/fire_scars.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
ax[3].imshow(image)
ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

if prediction:
if "prediction" in sample:
ax[4].title.set_text("Predicted Mask")
ax[4].imshow(prediction, cmap="jet", norm=norm)

Expand Down
2 changes: 1 addition & 1 deletion terratorch/datasets/landslide4sense.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) ->
ax[2].axis("off")

if "prediction" in sample:
prediction = sample["prediction"].numpy()
prediction = sample["prediction"]
ax[3].imshow(prediction, cmap=cmap, norm=norm)
ax[3].set_title("Predicted Mask")
ax[3].axis("off")
Expand Down
41 changes: 20 additions & 21 deletions terratorch/datasets/multi_temporal_crop_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,37 +235,35 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
raise ValueError(msg)

images = sample["image"]
if not self.expand_temporal_dimension:
images = rearrange(images, "(channels time) h w -> channels time h w", channels=len(self.bands))
images = images[rgb_indices, ...] # Shape: (T, 3, H, W)

# RGB -> channels-last
images = images[rgb_indices, ...].permute(1, 2, 3, 0).numpy()
mask = sample["mask"].numpy()

images = [clip_image(img) for img in images]
processed_images = []
for t in range(self.time_steps):
img = images[t]
img = img.permute(1, 2, 0)
img = img.numpy()
img = clip_image(img)
processed_images.append(img)

mask = sample["mask"].numpy()
if "prediction" in sample:
prediction = sample["prediction"]
num_images += 1
else:
prediction = None

fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

ax[0].axis("off")

norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
for i, img in enumerate(processed_images):
ax[i + 1].axis("off")
ax[i + 1].title.set_text(f"T{i}")
ax[i + 1].imshow(img)

for i, img in enumerate(images):
ax[i+1].axis("off")
ax[i+1].title.set_text(f"T{i}")
ax[i+1].imshow(img)

ax[self.time_steps+1].axis("off")
ax[self.time_steps+1].title.set_text("Ground Truth Mask")
ax[self.time_steps+1].imshow(mask, cmap="jet", norm=norm)
ax[self.time_steps + 1].axis("off")
ax[self.time_steps + 1].title.set_text("Ground Truth Mask")
ax[self.time_steps + 1].imshow(mask, cmap="jet", norm=norm)

if prediction:
if "prediction" in sample:
prediction = sample["prediction"]
ax[self.time_steps + 1].axis("off")
ax[self.time_steps+2].title.set_text("Predicted Mask")
ax[self.time_steps+2].imshow(prediction, cmap="jet", norm=norm)

Expand All @@ -274,6 +272,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
labels = [n for k, c, n in legend_data]
ax[0].legend(handles, labels, loc="center")

if suptitle is not None:
plt.suptitle(suptitle)

Expand Down
2 changes: 1 addition & 1 deletion terratorch/datasets/sen1floods11.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
ax[3].imshow(image)
ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

if prediction:
if "prediction" in sample:
ax[4].title.set_text("Predicted Mask")
ax[4].imshow(prediction, cmap="jet", norm=norm)

Expand Down

0 comments on commit 564866f

Please sign in to comment.