
Feat: Add compute_statistics subcommand #336

Open · wants to merge 6 commits into base: main
Conversation

fmartiescofet (Contributor)

This PR adds the compute_statistics subcommand to the terratorch CLI to compute the mean and std of a dataset.
It can be called with the same config file as any other subcommand: terratorch compute_statistics --config <file>.yaml

The **kwargs in the method are required because the Lightning CLI always passes the model parameter, which we need to consume, see here.
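An illustrative sketch (not the actual PR code; the function body and names are assumptions) of how a subcommand handler can swallow the model argument that Lightning CLI injects, via **kwargs:

```python
# Hypothetical sketch, not terratorch's real handler. Lightning CLI passes
# `model` to every subcommand, so the handler accepts **kwargs and discards
# the parameters it does not need.
def compute_statistics(datamodule, **kwargs):
    kwargs.pop("model", None)  # consume the unused `model` argument
    return f"computing statistics for {datamodule}"

result = compute_statistics("my_datamodule", model="unused_model")
```

The point is only that the extra parameter is accepted and ignored rather than raising a TypeError.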

Signed-off-by: Francesc Marti Escofet <[email protected]>
blumenstiel (Collaborator) left a comment


I like how you integrated it into the CLI tools. In general, I think it would be good to have this functionality in terratorch.utils so that it can also be called from Python code. Maybe you could just add a new statistics.py and have the Trainer call this function with the train dataloader?

I am not sure how we could handle custom datasets that do not fit the expected pattern. Maybe there could be a second option to pass only a folder instead of a config?
Not sure if this would generalise better. What do you think could work well?
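The kind of utility being proposed could look roughly like this sketch (function name, batch layout, and the numpy-based accumulation are assumptions, not terratorch's actual API):

```python
import numpy as np

def compute_statistics(dataloader):
    """Per-channel mean/std over batches shaped (B, C, H, W).

    Hypothetical helper in the spirit of the proposed terratorch.utils
    statistics module; accumulates sums and squared sums so the whole
    dataset never has to fit in memory at once.
    """
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None
    for batch in dataloader:
        x = np.asarray(batch, dtype=np.float64)
        c = x.shape[1]
        flat = x.transpose(1, 0, 2, 3).reshape(c, -1)  # (C, B*H*W)
        if channel_sum is None:
            channel_sum = np.zeros(c)
            channel_sq_sum = np.zeros(c)
        channel_sum += flat.sum(axis=1)
        channel_sq_sum += (flat ** 2).sum(axis=1)
        n_pixels += flat.shape[1]
    mean = channel_sum / n_pixels
    std = np.sqrt(channel_sq_sum / n_pixels - mean ** 2)
    return mean, std

# Toy "dataloader": two (B=2, C=3, H=4, W=4) batches of constant values.
batches = [np.ones((2, 3, 4, 4)), 3 * np.ones((2, 3, 4, 4))]
mean, std = compute_statistics(batches)
```

With the toy batches above, every channel sees half 1s and half 3s, so the per-channel mean is 2 and the population std is 1.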

terratorch/cli_tools.py (resolved)
terratorch/cli_tools.py (outdated, resolved)
fmartiescofet (Contributor, Author) commented Dec 23, 2024

The main problem I had while integrating it into the CLI is that a subcommand requires some parameters such as model... See here. My initial idea was to make the user pass only the dataset and create the dataloader manually inside the method, but this issue didn't allow it in an easy way (the user would always need to pass the model, making it odd to have a config with a dataset and an unused model). I therefore thought the best approach was to let the user reuse the same config they would use to call fit, leaving it to the user to make sure there is no randomness such as random cropping.

In the end, if we also allowed passing a folder, we would need to build the dataset inside the method and require more parameters to construct the generic dataset. This way, users can use their custom datasets (the only requirements are that the datamodule returns a dataloader from train_dataloader and that there is no randomness). Maybe I can document this clearly so users know how to use the subcommand. What do you think?

I agree with saving the output in YAML and separating it into a new file; I'll do it.

blumenstiel (Collaborator)

Thanks for the changes! You are right, we can expect some checks by the user to make sure it's working correctly.
But I think we can just add one check before running .setup('fit') to handle at least the generic datasets: we just need to replace train_transform with None.

if hasattr(datamodule, "train_transform"):
    datamodule.train_transform = None

What do you think?

About passing a dataset folder, I agree with you that it probably makes it more complicated. Users have to create a config anyway at some point, so they can just do it before computing the statistics.

Signed-off-by: Francesc Marti Escofet <[email protected]>
fmartiescofet (Contributor, Author)

@blumenstiel The issue with changing it in the datamodule is that it does not get changed in the dataset, as the dataset is already instantiated. I thought of overwriting the attribute in the dataset, but then we would overwrite user-defined transforms that they may want when computing the statistics. Also, we would need to keep the ToTensor transformation, since datasets usually return numpy arrays, and different datasets may use different methods to convert to tensors, which makes this hard.
I added some documentation about it, let me know what you think.

blumenstiel (Collaborator)

@fmartiescofet The datasets are built in .setup('fit'). If you overwrite it before, it should work.
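The ordering can be illustrated with a toy stand-in datamodule (purely hypothetical, not terratorch's or Lightning's real classes): because the dataset only captures the transform when .setup('fit') runs, overwriting train_transform beforehand does take effect.

```python
# Toy stand-in datamodule; illustrates the ordering only, not a real API.
class DummyDataModule:
    def __init__(self, train_transform):
        self.train_transform = train_transform
        self.train_dataset = None

    def setup(self, stage):
        if stage == "fit":
            # The dataset is built here, capturing the *current* transform.
            self.train_dataset = {"transform": self.train_transform}

dm = DummyDataModule(train_transform="random_crop")
if hasattr(dm, "train_transform"):
    dm.train_transform = None  # disable augmentations before datasets exist
dm.setup("fit")
# dm.train_dataset["transform"] is now None: the overwrite propagated
```

Had dm.setup("fit") already run, the dataset would have captured "random_crop" and the overwrite would have no effect, which is the concern raised above.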

blumenstiel (Collaborator) left a comment


Looks good, thanks!

blumenstiel (Collaborator)

@Joao-L-S-Almeida @romeokienzler I think we can merge this one.
