diff --git a/terratorch/datamodules/biomassters.py b/terratorch/datamodules/biomassters.py
index eaa04471..4e2cc05d 100644
--- a/terratorch/datamodules/biomassters.py
+++ b/terratorch/datamodules/biomassters.py
@@ -74,6 +74,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         aug: AugmentationSequential = None,
         drop_last: bool = True,
         sensors: Sequence[str] = ["S1", "S2"],
@@ -107,6 +108,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         if len(sensors) == 1:
             self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
         else:
@@ -176,6 +178,24 @@ def setup(self, stage: str) -> None:
                 seed=self.seed,
                 use_four_frames=self.use_four_frames,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="test",
+                root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands,
+                mask_mean=self.mask_mean,
+                mask_std=self.mask_std,
+                sensors=self.sensors,
+                as_time_series=self.as_time_series,
+                metadata_filename=self.metadata_filename,
+                max_cloud_percentage=self.max_cloud_percentage,
+                max_red_mean=self.max_red_mean,
+                include_corrupt=self.include_corrupt,
+                subset=self.subset,
+                seed=self.seed,
+                use_four_frames=self.use_four_frames,
+            )
 
     def _dataloader_factory(self, split: str):
         dataset = self._valid_attribute(f"{split}_dataset", "dataset")
diff --git a/terratorch/datamodules/burn_intensity.py b/terratorch/datamodules/burn_intensity.py
index 6c5b3343..4c371fcb 100644
--- a/terratorch/datamodules/burn_intensity.py
+++ b/terratorch/datamodules/burn_intensity.py
@@ -37,6 +37,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         use_full_data: bool = True,
         no_data_replace: float | None = 0.0001,
         no_label_replace: int | None = -1,
@@ -52,6 +53,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.aug = NormalizeWithTimesteps(means, stds)
         self.use_full_data = use_full_data
         self.no_data_replace = no_data_replace
@@ -92,3 +94,14 @@ def setup(self, stage: str) -> None:
                 no_label_replace=self.no_label_replace,
                 use_metadata=self.use_metadata,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="val",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands,
+                use_full_data=self.use_full_data,
+                no_data_replace=self.no_data_replace,
+                no_label_replace=self.no_label_replace,
+                use_metadata=self.use_metadata,
+            )
diff --git a/terratorch/datamodules/carbonflux.py b/terratorch/datamodules/carbonflux.py
index fb2f145f..7697cc27 100644
--- a/terratorch/datamodules/carbonflux.py
+++ b/terratorch/datamodules/carbonflux.py
@@ -50,6 +50,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         aug: AugmentationSequential = None,
         no_data_replace: float | None = 0.0001,
         use_metadata: bool = False,
@@ -72,6 +73,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.aug = MultimodalNormalize(means, stds) if aug is None else aug
         self.no_data_replace = no_data_replace
         self.use_metadata = use_metadata
@@ -110,3 +112,14 @@ def setup(self, stage: str) -> None:
                 no_data_replace=self.no_data_replace,
                 use_metadata=self.use_metadata,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="test",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands,
+                gpp_mean=self.mask_means,
+                gpp_std=self.mask_std,
+                no_data_replace=self.no_data_replace,
+                use_metadata=self.use_metadata,
+            )
diff --git a/terratorch/datamodules/fire_scars.py b/terratorch/datamodules/fire_scars.py
index 39038cae..0938f8d6 100644
--- a/terratorch/datamodules/fire_scars.py
+++ b/terratorch/datamodules/fire_scars.py
@@ -46,6 +46,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         drop_last: bool = True,
         no_data_replace: float | None = 0,
         no_label_replace: int | None = -1,
@@ -61,6 +62,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
         self.drop_last = drop_last
         self.no_data_replace = no_data_replace
@@ -98,6 +100,16 @@ def setup(self, stage: str) -> None:
                 no_label_replace=self.no_label_replace,
                 use_metadata=self.use_metadata,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="val",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands,
+                no_data_replace=self.no_data_replace,
+                no_label_replace=self.no_label_replace,
+                use_metadata=self.use_metadata,
+            )
 
     def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
         """Implement one or more PyTorch DataLoaders.
diff --git a/terratorch/datamodules/forestnet.py b/terratorch/datamodules/forestnet.py
index c78108d5..f46dd567 100644
--- a/terratorch/datamodules/forestnet.py
+++ b/terratorch/datamodules/forestnet.py
@@ -42,6 +42,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         fraction: float = 1.0,
         aug: AugmentationSequential = None,
         use_metadata: bool = False,
@@ -57,6 +58,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.aug = Normalize(self.means, self.stds) if aug is None else aug
         self.fraction = fraction
         self.use_metadata = use_metadata
@@ -92,3 +94,13 @@ def setup(self, stage: str) -> None:
                 fraction=self.fraction,
                 use_metadata=self.use_metadata,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="test",
+                data_root=self.data_root,
+                label_map=self.label_map,
+                transform=self.predict_transform,
+                bands=self.bands,
+                fraction=self.fraction,
+                use_metadata=self.use_metadata,
+            )
diff --git a/terratorch/datamodules/geobench_data_module.py b/terratorch/datamodules/geobench_data_module.py
index 1e509037..785f4c46 100644
--- a/terratorch/datamodules/geobench_data_module.py
+++ b/terratorch/datamodules/geobench_data_module.py
@@ -23,6 +23,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         aug: AugmentationSequential = None,
         partition: str = "default",
         **kwargs: Any,
@@ -35,6 +36,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.data_root = data_root
         self.partition = partition
         self.aug = (
@@ -69,3 +71,12 @@ def setup(self, stage: str) -> None:
                 bands=self.bands,
                 **self.kwargs,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="test",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                partition=self.partition,
+                bands=self.bands,
+                **self.kwargs,
+            )
diff --git a/terratorch/datamodules/landslide4sense.py b/terratorch/datamodules/landslide4sense.py
index 0e843907..84df0188 100644
--- a/terratorch/datamodules/landslide4sense.py
+++ b/terratorch/datamodules/landslide4sense.py
@@ -56,6 +56,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         aug: AugmentationSequential = None,
         **kwargs: Any,
     ) -> None:
@@ -68,6 +69,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.aug = (
             AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
         )
@@ -94,3 +96,10 @@ def setup(self, stage: str) -> None:
                 transform=self.test_transform,
                 bands=self.bands
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="test",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands
+            )
diff --git a/terratorch/datamodules/multi_temporal_crop_classification.py b/terratorch/datamodules/multi_temporal_crop_classification.py
index 4957e088..14452af4 100644
--- a/terratorch/datamodules/multi_temporal_crop_classification.py
+++ b/terratorch/datamodules/multi_temporal_crop_classification.py
@@ -41,6 +41,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         drop_last: bool = True,
         no_data_replace: float | None = 0,
         no_label_replace: int | None = -1,
@@ -58,6 +59,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.aug = Normalize(self.means, self.stds)
         self.drop_last = drop_last
         self.no_data_replace = no_data_replace
@@ -103,6 +105,18 @@ def setup(self, stage: str) -> None:
                 reduce_zero_label = self.reduce_zero_label,
                 use_metadata=self.use_metadata,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="val",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands,
+                no_data_replace=self.no_data_replace,
+                no_label_replace=self.no_label_replace,
+                expand_temporal_dimension = self.expand_temporal_dimension,
+                reduce_zero_label = self.reduce_zero_label,
+                use_metadata=self.use_metadata,
+            )
 
     def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
         """Implement one or more PyTorch DataLoaders.
diff --git a/terratorch/datamodules/open_sentinel_map.py b/terratorch/datamodules/open_sentinel_map.py
index fca6d730..36365b21 100644
--- a/terratorch/datamodules/open_sentinel_map.py
+++ b/terratorch/datamodules/open_sentinel_map.py
@@ -17,11 +17,10 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
         pad_image: int | None = None,
         truncate_image: int | None = None,
-        target: int = 0,
-        pick_random_pair: bool = True,  # noqa: FBT002, FBT001
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -34,11 +33,10 @@ def __init__(
         self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
         self.pad_image = pad_image
         self.truncate_image = truncate_image
-        self.target = target
-        self.pick_random_pair = pick_random_pair
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.data_root = data_root
         self.kwargs = kwargs
 
@@ -52,8 +50,6 @@ def setup(self, stage: str) -> None:
                 spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                 pad_image = self.pad_image,
                 truncate_image = self.truncate_image,
-                target = self.target,
-                pick_random_pair = self.pick_random_pair,
                 **self.kwargs,
             )
         if stage in ["fit", "validate"]:
@@ -65,8 +61,6 @@ def setup(self, stage: str) -> None:
                 spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                 pad_image = self.pad_image,
                 truncate_image = self.truncate_image,
-                target = self.target,
-                pick_random_pair = self.pick_random_pair,
                 **self.kwargs,
             )
         if stage in ["test"]:
@@ -78,7 +72,16 @@ def setup(self, stage: str) -> None:
                 spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                 pad_image = self.pad_image,
                 truncate_image = self.truncate_image,
-                target = self.target,
-                pick_random_pair = self.pick_random_pair,
+                **self.kwargs,
+            )
+        if stage in ["predict"]:
+            self.predict_dataset = OpenSentinelMap(
+                split="test",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands,
+                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
+                pad_image = self.pad_image,
+                truncate_image = self.truncate_image,
                 **self.kwargs,
             )
diff --git a/terratorch/datamodules/openearthmap.py b/terratorch/datamodules/openearthmap.py
index 613a6425..c4869ef3 100644
--- a/terratorch/datamodules/openearthmap.py
+++ b/terratorch/datamodules/openearthmap.py
@@ -29,20 +29,22 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         aug: AugmentationSequential = None,
         **kwargs: Any
     ) -> None:
         super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs)
-        
+
         bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names)
         self.means = torch.tensor([MEANS[b] for b in bands])
         self.stds = torch.tensor([STDS[b] for b in bands])
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.data_root = data_root
         self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
-        
+
     def setup(self, stage: str) -> None:
         if stage in ["fit"]:
             self.train_dataset = self.dataset_class(
@@ -55,4 +57,8 @@ def setup(self, stage: str) -> None:
         if stage in ["test"]:
             self.test_dataset = self.dataset_class(
                 split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
-            )
\ No newline at end of file
+            )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="test",data_root=self.data_root, transform=self.predict_transform, **self.kwargs
+            )
diff --git a/terratorch/datamodules/pastis.py b/terratorch/datamodules/pastis.py
index 76560851..7b3743c3 100644
--- a/terratorch/datamodules/pastis.py
+++ b/terratorch/datamodules/pastis.py
@@ -18,6 +18,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -31,6 +32,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.data_root = data_root
         self.kwargs = kwargs
 
@@ -62,3 +64,12 @@ def setup(self, stage: str) -> None:
                 pad_image=self.pad_image,
                 **self.kwargs,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = PASTIS(
+                folds=[5],
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                truncate_image=self.truncate_image,
+                pad_image=self.pad_image,
+                **self.kwargs,
+            )
diff --git a/terratorch/datamodules/sen1floods11.py b/terratorch/datamodules/sen1floods11.py
index b9e2ff68..b64902fe 100644
--- a/terratorch/datamodules/sen1floods11.py
+++ b/terratorch/datamodules/sen1floods11.py
@@ -57,6 +57,7 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         drop_last: bool = True,
         constant_scale: float = 0.0001,
         no_data_replace: float | None = 0,
@@ -73,6 +74,7 @@ def __init__(
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
         self.drop_last = drop_last
         self.constant_scale = constant_scale
@@ -114,6 +116,17 @@ def setup(self, stage: str) -> None:
                 no_label_replace=self.no_label_replace,
                 use_metadata=self.use_metadata,
             )
+        if stage in ["predict"]:
+            self.predict_dataset = self.dataset_class(
+                split="test",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands,
+                constant_scale=self.constant_scale,
+                no_data_replace=self.no_data_replace,
+                no_label_replace=self.no_label_replace,
+                use_metadata=self.use_metadata,
+            )
 
     def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
         """Implement one or more PyTorch DataLoaders.
diff --git a/terratorch/datamodules/sen4agrinet.py b/terratorch/datamodules/sen4agrinet.py
index 68652093..9fd67739 100644
--- a/terratorch/datamodules/sen4agrinet.py
+++ b/terratorch/datamodules/sen4agrinet.py
@@ -1,10 +1,10 @@
 from typing import Any
 
 import albumentations as A  # noqa: N812
-from torchgeo.datamodules import NonGeoDataModule
 
 from terratorch.datamodules.utils import wrap_in_compose_is_list
 from terratorch.datasets import Sen4AgriNet
+from torchgeo.datamodules import NonGeoDataModule
 
 
 class Sen4AgriNetDataModule(NonGeoDataModule):
@@ -17,10 +17,12 @@ def __init__(
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
         val_transform: A.Compose | None | list[A.BasicTransform] = None,
         test_transform: A.Compose | None | list[A.BasicTransform] = None,
-        truncate_image: int | None = 6,
-        pad_image: int | None = 6,
-        spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT002, FBT001
+        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
         seed: int = 42,
+        scenario: str = "random",
+        requires_norm: bool = True,
+        binary_labels: bool = False,
+        linear_encoder: dict = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -30,17 +32,18 @@ def __init__(
             **kwargs,
         )
         self.bands = bands
-        self.truncate_image = truncate_image
-        self.pad_image = pad_image
-        self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
         self.seed = seed
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.predict_transform = wrap_in_compose_is_list(predict_transform)
         self.data_root = data_root
+        self.scenario = scenario
+        self.requires_norm = requires_norm
+        self.binary_labels = binary_labels
+        self.linear_encoder = linear_encoder
         self.kwargs = kwargs
-
 
     def setup(self, stage: str) -> None:
         if stage in ["fit"]:
             self.train_dataset = Sen4AgriNet(
@@ -48,10 +51,11 @@ def setup(self, stage: str) -> None:
                 data_root=self.data_root,
                 transform=self.train_transform,
                 bands=self.bands,
-                truncate_image = self.truncate_image,
-                pad_image = self.pad_image,
-                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
-                seed = self.seed,
+                seed=self.seed,
+                scenario=self.scenario,
+                requires_norm=self.requires_norm,
+                binary_labels=self.binary_labels,
+                linear_encoder=self.linear_encoder,
                 **self.kwargs,
             )
         if stage in ["fit", "validate"]:
@@ -60,10 +64,11 @@ def setup(self, stage: str) -> None:
                 data_root=self.data_root,
                 transform=self.val_transform,
                 bands=self.bands,
-                truncate_image = self.truncate_image,
-                pad_image = self.pad_image,
-                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
-                seed = self.seed,
+                seed=self.seed,
+                scenario=self.scenario,
+                requires_norm=self.requires_norm,
+                binary_labels=self.binary_labels,
+                linear_encoder=self.linear_encoder,
                 **self.kwargs,
             )
         if stage in ["test"]:
@@ -72,9 +77,23 @@ def setup(self, stage: str) -> None:
                 data_root=self.data_root,
                 transform=self.test_transform,
                 bands=self.bands,
-                truncate_image = self.truncate_image,
-                pad_image = self.pad_image,
-                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
-                seed = self.seed,
+                seed=self.seed,
+                scenario=self.scenario,
+                requires_norm=self.requires_norm,
+                binary_labels=self.binary_labels,
+                linear_encoder=self.linear_encoder,
+                **self.kwargs,
+            )
+        if stage in ["predict"]:
+            self.predict_dataset = Sen4AgriNet(
+                split="test",
+                data_root=self.data_root,
+                transform=self.predict_transform,
+                bands=self.bands,
+                seed=self.seed,
+                scenario=self.scenario,
+                requires_norm=self.requires_norm,
+                binary_labels=self.binary_labels,
+                linear_encoder=self.linear_encoder,
                 **self.kwargs,
             )
diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py
index 89974004..f91e1836 100644
--- a/terratorch/tasks/classification_tasks.py
+++ b/terratorch/tasks/classification_tasks.py
@@ -258,7 +258,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
             Output predicted probabilities.
         """
         x = batch["image"]
-        file_names = batch["filename"]
+        file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "label", "filename"}
         rest = {k:batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py
index 29bbc00f..bbc1dd48 100644
--- a/terratorch/tasks/regression_tasks.py
+++ b/terratorch/tasks/regression_tasks.py
@@ -348,7 +348,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
             Output predicted probabilities.
         """
         x = batch["image"]
-        file_names = batch["filename"]
+        file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "mask", "filename"}
         rest = {k:batch[k] for k in other_keys}
 
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index 819a424a..2fdfdc0c 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -330,7 +330,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
             Output predicted probabilities.
         """
         x = batch["image"]
-        file_names = batch["filename"]
+        file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "mask", "filename"}
         rest = {k: batch[k] for k in other_keys}
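
For context, a minimal sketch of how the new `predict` stage and `predict_transform` argument are intended to be exercised. The datamodule class name `Sen1Floods11NonGeoDataModule`, the `data_root` path, and the transform choice are assumptions for illustration and not part of this diff; exact constructor arguments may differ per datamodule.

```python
# Hypothetical usage sketch for the new predict_transform / predict stage.
import albumentations as A
from terratorch.datamodules import Sen1Floods11NonGeoDataModule  # assumed class name

datamodule = Sen1Floods11NonGeoDataModule(
    batch_size=8,
    num_workers=4,
    data_root="data/sen1floods11",               # hypothetical local path
    # A plain list is wrapped into an A.Compose by wrap_in_compose_is_list,
    # mirroring the existing train/val/test_transform arguments.
    predict_transform=[A.CenterCrop(512, 512)],
)

# setup("predict") now builds predict_dataset from the chosen split using
# predict_transform; NonGeoDataModule serves it via predict_dataloader().
datamodule.setup("predict")
loader = datamodule.predict_dataloader()
batch = next(iter(loader))

# Task-side change: predict_step no longer requires batch["filename"];
# when it is absent, file_names simply falls back to None.
print(batch["image"].shape, "filename" in batch)
```

The same flow should apply when `trainer.predict(model=task, datamodule=datamodule)` is used, since Lightning calls `setup("predict")` and `predict_dataloader()` internally.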