From 8de35006199ec9dc610223679564e0fcba43e501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 16 Aug 2024 10:07:49 -0300 Subject: [PATCH 1/2] The CLI argument --data.init_args.predict_output_bands was missing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/cli_tools.py | 7 ++++++- terratorch/datamodules/generic_pixel_wise_data_module.py | 6 ++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index 171710f6..25960adb 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -398,6 +398,7 @@ def from_config( config_path: Path, checkpoint_path: Path | None = None, predict_dataset_bands: list[str] | None = None, + predict_output_bands: list[str] | None = None, ): """ Args: @@ -416,6 +417,10 @@ def from_config( arguments.extend([ "--data.init_args.predict_dataset_bands", "[" + ",".join(predict_dataset_bands) + "]",]) + if predict_output_bands is not None: + arguments.extend([ "--data.init_args.predict_output_bands", + "[" + ",".join(predict_output_bands) + "]",]) + cli = build_lightning_cli(arguments, run=False) trainer = cli.trainer # disable logging metrics @@ -467,4 +472,4 @@ def inference(self, file_path: Path) -> torch.Tensor: prediction, file_name = self.inference_on_dir( tmpdir, ) - return prediction.squeeze(0) \ No newline at end of file + return prediction.squeeze(0) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 434f7488..ea6657b2 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -272,7 +272,7 @@ def setup(self, stage: str) -> None: self.predict_root, self.num_classes, dataset_bands=self.predict_dataset_bands, - output_bands=self.output_bands, + output_bands=self.predict_output_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, @@ -335,6 +335,7 @@ def __init__( allow_substring_split_file: bool = True, dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, @@ -426,6 +427,7 @@ def __init__( self.dataset_bands = dataset_bands self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands + self.predict_output_bands = predict_output_bands if predict_output_bands else dataset_bands self.output_bands = output_bands self.rgb_indices = rgb_indices @@ -507,7 +509,7 @@ def setup(self, stage: str) -> None: self.predict_dataset = self.dataset_class( self.predict_root, dataset_bands=self.predict_dataset_bands, - output_bands=self.output_bands, + output_bands=self.predict_output_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, From f62921917a56f654e3939d5c57c613a04f33da3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Wed, 21 Aug 2024 12:04:22 -0300 Subject: [PATCH 2/2] output_bands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datamodules/generic_pixel_wise_data_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index ea6657b2..86e57c36 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -427,7 +427,7 @@ def __init__( self.dataset_bands = dataset_bands self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands - self.predict_output_bands = predict_output_bands if predict_output_bands else dataset_bands + self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands self.output_bands = output_bands self.rgb_indices = rgb_indices