diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index e0d523e1..171710f6 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -123,7 +123,7 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): # output_dir = self.output_dir if not os.path.exists(output_dir): - os.mkdir(output_dir) + os.makedirs(output_dir, exist_ok=True) for pred_batch, filename_batch in predictions: for prediction, file_name in zip(torch.unbind(pred_batch, dim=0), filename_batch, strict=False): @@ -467,4 +467,4 @@ def inference(self, file_path: Path) -> torch.Tensor: prediction, file_name = self.inference_on_dir( tmpdir, ) - return prediction.squeeze(0) + return prediction.squeeze(0) \ No newline at end of file