Skip to content

Commit

Permalink
Add docs and compute mask stats
Browse files Browse the repository at this point in the history
Signed-off-by: Francesc Marti Escofet <[email protected]>
  • Loading branch information
fmartiescofet committed Jan 7, 2025
1 parent 2d69f06 commit 6b5b3d4
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
22 changes: 18 additions & 4 deletions terratorch/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from tqdm import tqdm

import terratorch.datamodules
from terratorch.utils import compute_statistics
from terratorch.utils import compute_mask_statistics, compute_statistics
import terratorch.tasks # noqa: F401
from terratorch.datamodules import ( # noqa: F401
GenericNonGeoClassificationDataModule,
Expand Down Expand Up @@ -574,6 +574,16 @@ def inference(self, file_path: Path) -> torch.Tensor:

class MyTrainer(Trainer):
def compute_statistics(self, datamodule: LightningDataModule, **kwargs) -> None:
"""
Compute the dataset statistics for the training dataset.
This method will compute the mean and standard deviation of the image data and the count and percentage of each
unique value in the masks in case these are int and the mean and standard deviation of the mask values in case
these are floats. The statistics are computed using the entire training dataset and are printed to the logger.
Please note that this method assumes that there is only one train dataloader in the datamodule and that the
dataset does not have any transforms that may introduce randomness.
"""
datamodule.setup("fit")
original_dataloader = datamodule.train_dataloader()
if not isinstance(original_dataloader, DataLoader):
Expand All @@ -588,6 +598,10 @@ def compute_statistics(self, datamodule: LightningDataModule, **kwargs) -> None:
pin_memory=original_dataloader.pin_memory,
drop_last=False,
)
mean, std = compute_statistics(new_dataloader)

logger.info(yaml.dump({"means": mean, "stds": std}))
image_stats = compute_statistics(new_dataloader)
logger.info("Image statistics:")
logger.info(yaml.dump(image_stats))
if "mask" in datamodule.train_dataloader().dataset[0]:
mask_stats = compute_mask_statistics(new_dataloader)
logger.info("Mask statistics:")
logger.info(yaml.dump(mask_stats))
45 changes: 43 additions & 2 deletions terratorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import math
from collections import Counter

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def compute_statistics(dataloader: DataLoader) -> tuple[list[float], list[float]]:
def compute_statistics(dataloader: DataLoader) -> dict[str, list[float]]:
n_bands = dataloader.dataset[0]["image"].shape[0]
n_data = torch.zeros([n_bands], dtype=torch.int64)
sum_data = torch.zeros([n_bands], dtype=torch.float64)
Expand All @@ -25,5 +28,43 @@ def compute_statistics(dataloader: DataLoader) -> tuple[list[float], list[float]

variance = sum_squared / n_data
std = torch.sqrt(variance)
return {"means": mean.numpy().tolist(), "stds": std.numpy().tolist()}


def compute_mask_statistics(dataloader: DataLoader) -> dict[int, dict[str, int | float]] | dict[str, float]:
if torch.is_floating_point(dataloader.dataset[0]["mask"]):
return compute_float_mask_statistics(dataloader)
else:
return compute_int_mask_statistics(dataloader)


def compute_int_mask_statistics(dataloader: DataLoader) -> dict[int, dict[str, int | float]]:
counter = Counter()
for batch in tqdm(dataloader, desc="Compute counts"):
masks: torch.Tensor = batch["mask"]
counter.update(masks.flatten().tolist())

stats = {}
for key, count in counter.items():
stats[key] = {"count": count, "percentage": count / counter.total()}
return stats

return mean.numpy().tolist(), std.numpy().tolist()

def compute_float_mask_statistics(dataloader: DataLoader) -> dict[str, float]:
n_data = 0
total = 0.0

for batch in tqdm(dataloader, desc="Compute mask mean"):
masks: torch.Tensor = batch["mask"]
total += masks.sum().item()
n_data += masks.numel()
mean = total / n_data

sum_squared = 0.0
for batch in tqdm(dataloader, desc="Compute mask variance"):
masks = batch["mask"]
sum_squared += ((masks - mean) ** 2).sum().item()

variance = sum_squared / n_data
std = math.sqrt(variance)
return {"mean": mean, "std": std}

0 comments on commit 6b5b3d4

Please sign in to comment.