Skip to content

Commit

Permalink
Merge pull request #337 from PedroConrado/feature/predict-datamodules
Browse files Browse the repository at this point in the history
adds predict to datamodules
Joao-L-S-Almeida authored Jan 9, 2025
2 parents 75d32cd + 35f39bc commit 835d4b2
Showing 16 changed files with 192 additions and 36 deletions.
20 changes: 20 additions & 0 deletions terratorch/datamodules/biomassters.py
Original file line number Diff line number Diff line change
@@ -74,6 +74,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
drop_last: bool = True,
sensors: Sequence[str] = ["S1", "S2"],
@@ -107,6 +108,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
if len(sensors) == 1:
self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
else:
@@ -176,6 +178,24 @@ def setup(self, stage: str) -> None:
seed=self.seed,
use_four_frames=self.use_four_frames,
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="test",
root=self.data_root,
transform=self.predict_transform,
bands=self.bands,
mask_mean=self.mask_mean,
mask_std=self.mask_std,
sensors=self.sensors,
as_time_series=self.as_time_series,
metadata_filename=self.metadata_filename,
max_cloud_percentage=self.max_cloud_percentage,
max_red_mean=self.max_red_mean,
include_corrupt=self.include_corrupt,
subset=self.subset,
seed=self.seed,
use_four_frames=self.use_four_frames,
)

def _dataloader_factory(self, split: str):
dataset = self._valid_attribute(f"{split}_dataset", "dataset")
13 changes: 13 additions & 0 deletions terratorch/datamodules/burn_intensity.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
use_full_data: bool = True,
no_data_replace: float | None = 0.0001,
no_label_replace: int | None = -1,
@@ -52,6 +53,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.aug = NormalizeWithTimesteps(means, stds)
self.use_full_data = use_full_data
self.no_data_replace = no_data_replace
@@ -92,3 +94,14 @@ def setup(self, stage: str) -> None:
no_label_replace=self.no_label_replace,
use_metadata=self.use_metadata,
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="val",
data_root=self.data_root,
transform=self.predict_transform,
bands=self.bands,
use_full_data=self.use_full_data,
no_data_replace=self.no_data_replace,
no_label_replace=self.no_label_replace,
use_metadata=self.use_metadata,
)
13 changes: 13 additions & 0 deletions terratorch/datamodules/carbonflux.py
Original file line number Diff line number Diff line change
@@ -50,6 +50,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
no_data_replace: float | None = 0.0001,
use_metadata: bool = False,
@@ -72,6 +73,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.aug = MultimodalNormalize(means, stds) if aug is None else aug
self.no_data_replace = no_data_replace
self.use_metadata = use_metadata
@@ -110,3 +112,14 @@ def setup(self, stage: str) -> None:
no_data_replace=self.no_data_replace,
use_metadata=self.use_metadata,
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="test",
data_root=self.data_root,
transform=self.predict_transform,
bands=self.bands,
gpp_mean=self.mask_means,
gpp_std=self.mask_std,
no_data_replace=self.no_data_replace,
use_metadata=self.use_metadata,
)
12 changes: 12 additions & 0 deletions terratorch/datamodules/fire_scars.py
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
drop_last: bool = True,
no_data_replace: float | None = 0,
no_label_replace: int | None = -1,
@@ -61,6 +62,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
self.drop_last = drop_last
self.no_data_replace = no_data_replace
@@ -98,6 +100,16 @@ def setup(self, stage: str) -> None:
no_label_replace=self.no_label_replace,
use_metadata=self.use_metadata,
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="val",
data_root=self.data_root,
transform=self.predict_transform,
bands=self.bands,
no_data_replace=self.no_data_replace,
no_label_replace=self.no_label_replace,
use_metadata=self.use_metadata,
)

def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
"""Implement one or more PyTorch DataLoaders.
12 changes: 12 additions & 0 deletions terratorch/datamodules/forestnet.py
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
fraction: float = 1.0,
aug: AugmentationSequential = None,
use_metadata: bool = False,
@@ -57,6 +58,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.aug = Normalize(self.means, self.stds) if aug is None else aug
self.fraction = fraction
self.use_metadata = use_metadata
@@ -92,3 +94,13 @@ def setup(self, stage: str) -> None:
fraction=self.fraction,
use_metadata=self.use_metadata,
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="test",
data_root=self.data_root,
label_map=self.label_map,
transform=self.predict_transform,
bands=self.bands,
fraction=self.fraction,
use_metadata=self.use_metadata,
)
11 changes: 11 additions & 0 deletions terratorch/datamodules/geobench_data_module.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
partition: str = "default",
**kwargs: Any,
@@ -35,6 +36,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.data_root = data_root
self.partition = partition
self.aug = (
@@ -69,3 +71,12 @@ def setup(self, stage: str) -> None:
bands=self.bands,
**self.kwargs,
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="test",
data_root=self.data_root,
transform=self.predict_transform,
partition=self.partition,
bands=self.bands,
**self.kwargs,
)
9 changes: 9 additions & 0 deletions terratorch/datamodules/landslide4sense.py
Original file line number Diff line number Diff line change
@@ -56,6 +56,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any,
) -> None:
@@ -68,6 +69,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
)
@@ -94,3 +96,10 @@ def setup(self, stage: str) -> None:
transform=self.test_transform,
bands=self.bands
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="test",
data_root=self.data_root,
transform=self.predict_transform,
bands=self.bands
)
14 changes: 14 additions & 0 deletions terratorch/datamodules/multi_temporal_crop_classification.py
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
drop_last: bool = True,
no_data_replace: float | None = 0,
no_label_replace: int | None = -1,
@@ -58,6 +59,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.aug = Normalize(self.means, self.stds)
self.drop_last = drop_last
self.no_data_replace = no_data_replace
@@ -103,6 +105,18 @@ def setup(self, stage: str) -> None:
reduce_zero_label = self.reduce_zero_label,
use_metadata=self.use_metadata,
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="val",
data_root=self.data_root,
transform=self.predict_transform,
bands=self.bands,
no_data_replace=self.no_data_replace,
no_label_replace=self.no_label_replace,
expand_temporal_dimension = self.expand_temporal_dimension,
reduce_zero_label = self.reduce_zero_label,
use_metadata=self.use_metadata,
)

def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
"""Implement one or more PyTorch DataLoaders.
23 changes: 13 additions & 10 deletions terratorch/datamodules/open_sentinel_map.py
Original file line number Diff line number Diff line change
@@ -17,11 +17,10 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
spatial_interpolate_and_stack_temporally: bool = True, # noqa: FBT001, FBT002
pad_image: int | None = None,
truncate_image: int | None = None,
target: int = 0,
pick_random_pair: bool = True, # noqa: FBT002, FBT001
**kwargs: Any,
) -> None:
super().__init__(
@@ -34,11 +33,10 @@ def __init__(
self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
self.pad_image = pad_image
self.truncate_image = truncate_image
self.target = target
self.pick_random_pair = pick_random_pair
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.data_root = data_root
self.kwargs = kwargs

@@ -52,8 +50,6 @@ def setup(self, stage: str) -> None:
spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
pad_image = self.pad_image,
truncate_image = self.truncate_image,
target = self.target,
pick_random_pair = self.pick_random_pair,
**self.kwargs,
)
if stage in ["fit", "validate"]:
@@ -65,8 +61,6 @@ def setup(self, stage: str) -> None:
spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
pad_image = self.pad_image,
truncate_image = self.truncate_image,
target = self.target,
pick_random_pair = self.pick_random_pair,
**self.kwargs,
)
if stage in ["test"]:
@@ -78,7 +72,16 @@ def setup(self, stage: str) -> None:
spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
pad_image = self.pad_image,
truncate_image = self.truncate_image,
target = self.target,
pick_random_pair = self.pick_random_pair,
**self.kwargs,
)
if stage in ["predict"]:
self.predict_dataset = OpenSentinelMap(
split="test",
data_root=self.data_root,
transform=self.predict_transform,
bands=self.bands,
spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
pad_image = self.pad_image,
truncate_image = self.truncate_image,
**self.kwargs,
)
12 changes: 9 additions & 3 deletions terratorch/datamodules/openearthmap.py
Original file line number Diff line number Diff line change
@@ -29,20 +29,22 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
) -> None:
super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
@@ -55,4 +57,8 @@ def setup(self, stage: str) -> None:
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.predict_transform, **self.kwargs
)
11 changes: 11 additions & 0 deletions terratorch/datamodules/pastis.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -31,6 +32,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.data_root = data_root
self.kwargs = kwargs

@@ -62,3 +64,12 @@ def setup(self, stage: str) -> None:
pad_image=self.pad_image,
**self.kwargs,
)
if stage in ["predict"]:
self.predict_dataset = PASTIS(
folds=[5],
data_root=self.data_root,
transform=self.predict_transform,
truncate_image=self.truncate_image,
pad_image=self.pad_image,
**self.kwargs,
)
13 changes: 13 additions & 0 deletions terratorch/datamodules/sen1floods11.py
Original file line number Diff line number Diff line change
@@ -57,6 +57,7 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
drop_last: bool = True,
constant_scale: float = 0.0001,
no_data_replace: float | None = 0,
@@ -73,6 +74,7 @@ def __init__(
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
self.drop_last = drop_last
self.constant_scale = constant_scale
@@ -114,6 +116,17 @@ def setup(self, stage: str) -> None:
no_label_replace=self.no_label_replace,
use_metadata=self.use_metadata,
)
if stage in ["predict"]:
self.predict_dataset = self.dataset_class(
split="test",
data_root=self.data_root,
transform=self.predict_transform,
bands=self.bands,
constant_scale=self.constant_scale,
no_data_replace=self.no_data_replace,
no_label_replace=self.no_label_replace,
use_metadata=self.use_metadata,
)

def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
"""Implement one or more PyTorch DataLoaders.
59 changes: 39 additions & 20 deletions terratorch/datamodules/sen4agrinet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any

import albumentations as A # noqa: N812
from torchgeo.datamodules import NonGeoDataModule

from terratorch.datamodules.utils import wrap_in_compose_is_list
from terratorch.datasets import Sen4AgriNet
from torchgeo.datamodules import NonGeoDataModule


class Sen4AgriNetDataModule(NonGeoDataModule):
@@ -17,10 +17,12 @@ def __init__(
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
truncate_image: int | None = 6,
pad_image: int | None = 6,
spatial_interpolate_and_stack_temporally: bool = True, # noqa: FBT002, FBT001
predict_transform: A.Compose | None | list[A.BasicTransform] = None,
seed: int = 42,
scenario: str = "random",
requires_norm: bool = True,
binary_labels: bool = False,
linear_encoder: dict = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -30,28 +32,30 @@ def __init__(
**kwargs,
)
self.bands = bands
self.truncate_image = truncate_image
self.pad_image = pad_image
self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
self.seed = seed
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.predict_transform = wrap_in_compose_is_list(predict_transform)
self.data_root = data_root
self.scenario = scenario
self.requires_norm = requires_norm
self.binary_labels = binary_labels
self.linear_encoder = linear_encoder
self.kwargs = kwargs


def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = Sen4AgriNet(
split="train",
data_root=self.data_root,
transform=self.train_transform,
bands=self.bands,
truncate_image = self.truncate_image,
pad_image = self.pad_image,
spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
seed = self.seed,
seed=self.seed,
scenario=self.scenario,
requires_norm=self.requires_norm,
binary_labels=self.binary_labels,
linear_encoder=self.linear_encoder,
**self.kwargs,
)
if stage in ["fit", "validate"]:
@@ -60,10 +64,11 @@ def setup(self, stage: str) -> None:
data_root=self.data_root,
transform=self.val_transform,
bands=self.bands,
truncate_image = self.truncate_image,
pad_image = self.pad_image,
spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
seed = self.seed,
seed=self.seed,
scenario=self.scenario,
requires_norm=self.requires_norm,
binary_labels=self.binary_labels,
linear_encoder=self.linear_encoder,
**self.kwargs,
)
if stage in ["test"]:
@@ -72,9 +77,23 @@ def setup(self, stage: str) -> None:
data_root=self.data_root,
transform=self.test_transform,
bands=self.bands,
truncate_image = self.truncate_image,
pad_image = self.pad_image,
spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
seed = self.seed,
seed=self.seed,
scenario=self.scenario,
requires_norm=self.requires_norm,
binary_labels=self.binary_labels,
linear_encoder=self.linear_encoder,
**self.kwargs,
)
if stage in ["predict"]:
self.predict_dataset = Sen4AgriNet(
split="test",
data_root=self.data_root,
transform=self.predict_transform,
bands=self.bands,
seed=self.seed,
scenario=self.scenario,
requires_norm=self.requires_norm,
binary_labels=self.binary_labels,
linear_encoder=self.linear_encoder,
**self.kwargs,
)
2 changes: 1 addition & 1 deletion terratorch/tasks/classification_tasks.py
Original file line number Diff line number Diff line change
@@ -258,7 +258,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
Output predicted probabilities.
"""
x = batch["image"]
file_names = batch["filename"]
file_names = batch["filename"] if "filename" in batch else None
other_keys = batch.keys() - {"image", "label", "filename"}
rest = {k:batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)
2 changes: 1 addition & 1 deletion terratorch/tasks/regression_tasks.py
Original file line number Diff line number Diff line change
@@ -348,7 +348,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
Output predicted probabilities.
"""
x = batch["image"]
file_names = batch["filename"]
file_names = batch["filename"] if "filename" in batch else None
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k:batch[k] for k in other_keys}

2 changes: 1 addition & 1 deletion terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
@@ -330,7 +330,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
Output predicted probabilities.
"""
x = batch["image"]
file_names = batch["filename"]
file_names = batch["filename"] if "filename" in batch else None
other_keys = batch.keys() - {"image", "mask", "filename"}

rest = {k: batch[k] for k in other_keys}

0 comments on commit 835d4b2

Please sign in to comment.