diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index 171710f6..25960adb 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -398,6 +398,7 @@ def from_config( config_path: Path, checkpoint_path: Path | None = None, predict_dataset_bands: list[str] | None = None, + predict_output_bands: list[str] | None = None, ): """ Args: @@ -416,6 +417,10 @@ def from_config( arguments.extend([ "--data.init_args.predict_dataset_bands", "[" + ",".join(predict_dataset_bands) + "]",]) + if predict_output_bands is not None: + arguments.extend([ "--data.init_args.predict_output_bands", + "[" + ",".join(predict_output_bands) + "]",]) + cli = build_lightning_cli(arguments, run=False) trainer = cli.trainer # disable logging metrics @@ -467,4 +472,4 @@ def inference(self, file_path: Path) -> torch.Tensor: prediction, file_name = self.inference_on_dir( tmpdir, ) - return prediction.squeeze(0) \ No newline at end of file + return prediction.squeeze(0) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 434f7488..86e57c36 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -272,7 +272,7 @@ def setup(self, stage: str) -> None: self.predict_root, self.num_classes, dataset_bands=self.predict_dataset_bands, - output_bands=self.output_bands, + output_bands=self.predict_output_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform, @@ -335,6 +335,7 @@ def __init__( allow_substring_split_file: bool = True, dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, @@ -426,6 +427,7 @@ def __init__( self.dataset_bands = dataset_bands self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands + self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands self.output_bands = output_bands self.rgb_indices = rgb_indices @@ -507,7 +509,7 @@ def setup(self, stage: str) -> None: self.predict_dataset = self.dataset_class( self.predict_root, dataset_bands=self.predict_dataset_bands, - output_bands=self.output_bands, + output_bands=self.predict_output_bands, constant_scale=self.constant_scale, rgb_indices=self.rgb_indices, transform=self.test_transform,