From e0efd15584dac5ea3b9b060c53e5079c173c3854 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 20 Feb 2025 09:18:41 +0100 Subject: [PATCH] fix padding for decoders Signed-off-by: Benedikt Blumenstiel --- terratorch/models/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/terratorch/models/utils.py b/terratorch/models/utils.py index 2e68b947..89e7e57e 100644 --- a/terratorch/models/utils.py +++ b/terratorch/models/utils.py @@ -31,6 +31,9 @@ def pad_images(imgs: Tensor, patch_size: int | list, padding: str) -> Tensor: else: raise ValueError(f'patch size {patch_size} not valid, must be int or list of ints with length 1, 2 or 3.') + # Double the patch size to ensure the resulting number of patches is divisible by 2 (required for many decoders) + p_h, p_w = p_h * 2, p_w * 2 + if p_t > 1 and len(imgs.shape) < 5: raise ValueError(f"Multi-temporal padding requested (p_t = {p_t}) " f"but no multi-temporal data provided (data shape = {imgs.shape}).")