🚀 Feature
Make MultiTaskWrapper (and possibly other wrappers) support logging in pytorch_lightning.LightningModule. I expected it to just work with self.log_dict, but it doesn't, since MultiTaskWrapper is subclassed from Metric. It would be great if it had MetricCollection-like functionality.
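For illustration, this is roughly the usage I expected to work (a hypothetical snippet; `self.multitask` stands for a MultiTaskWrapper attribute registered in `__init__`):

```python
# Hypothetical LightningModule excerpt, not working code from this issue.
def training_step(self, batch, batch_idx):
    preds, targets = ...  # per-task dicts of predictions and targets
    self.multitask.update(preds, targets)
    # Expected this to aggregate and log per-task values at epoch end, the way
    # a MetricCollection does, but it fails because MultiTaskWrapper is a plain
    # Metric whose compute() returns a dict:
    self.log_dict(self.multitask, on_step=False, on_epoch=True)
```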
Motivation
Better compatibility with pytorch_lightning is always good; also, this standardizes the APIs of different torchmetrics tools.
Pitch
TorchMetrics's class-based metrics are great at standardizing metric calculation, which in most scenarios makes it possible to share boilerplate code between modules, such as logging metrics at each step (see Additional context for an example). This stems from torchmetrics' compatibility with pytorch_lightning. Most people using this library probably use it in conjunction with lightning, so increasing compatibility with lightning would definitely be an improvement.
Alternatives
Of course, one could use MultiTaskWrapper in the way outlined in the official example: calculate the metric, get a plain Python dictionary, and log that dictionary (sketched after the list below). But this has a few shortcomings:
The epoch-wise metrics could be wrong! E.g. if one uses a classification metric with average="macro", logging each batch separately and then averaging over the batches to get the epoch-level metric is simply incorrect. One would instead need to manually save all targets and predictions and calculate the metric in the on_*_end hook.
This leaves room for confusion: I personally expected log_dict to work, and only after looking at the source code did I realize that it can't. There is not a word about pytorch_lightning in the docs, so one expects it to work like the other metrics do!
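For reference, a sketch of that manual approach (the task names and metrics mirror the official MultiTaskWrapper example; the logging lines are my own illustration):

```python
import torch
from torchmetrics.wrappers import MultiTaskWrapper
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.regression import MeanSquaredError

metrics = MultiTaskWrapper({
    "classification": BinaryAccuracy(),
    "regression": MeanSquaredError(),
})

preds = {"classification": torch.rand(8), "regression": torch.rand(8)}
targets = {"classification": torch.randint(2, (8,)), "regression": torch.rand(8)}

results = metrics(preds, targets)  # a plain Python dict of tensors

# Inside a LightningModule one would then have to log the plain values:
#     self.log_dict({f"train_{k}": v for k, v in results.items()})
# With on_epoch=True Lightning can only average these batch values, which is
# exactly the incorrect epoch aggregation described in the first point above.
```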
Additional context
Let me give an example of removing shared boilerplate logic from submodules. This could be the abstract parent class with the boilerplate code that all submodules share.
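Something along these lines (a minimal sketch; `BaseModule`, `shared_step`, and `loss_fn` are illustrative names of my own, and it assumes Lightning's `log_dict` accepts a `MetricCollection`, as shown in the torchmetrics integration docs):

```python
from abc import ABC, abstractmethod

import pytorch_lightning as pl
import torch
from torchmetrics import MetricCollection


class BaseModule(pl.LightningModule, ABC):
    """Shared boilerplate: children set self.loss_fn, self.train_metrics and
    self.val_metrics in __init__ and implement shared_step."""

    loss_fn: torch.nn.Module
    train_metrics: MetricCollection
    val_metrics: MetricCollection

    @abstractmethod
    def shared_step(self, batch):
        """Return (preds, targets) for a batch; prediction logic lives here."""

    def training_step(self, batch, batch_idx):
        preds, targets = self.shared_step(batch)
        loss = self.loss_fn(preds, targets)
        self.train_metrics.update(preds, targets)
        # Passing the MetricCollection itself lets Lightning compute proper
        # epoch-level values instead of naively averaging batch values.
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        preds, targets = self.shared_step(batch)
        self.val_metrics.update(preds, targets)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True)
```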
And with that, all a child class has to do is specify its metrics and loss in __init__ and override shared_step with its specific prediction logic (see the sketch below). This would even work with different multitask modules, if only MultiTaskWrapper supported this behaviour.
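For instance (a hypothetical child class, reusing the `BaseModule` sketch above):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score


class ClassifierModule(BaseModule):
    def __init__(self, backbone: torch.nn.Module, num_classes: int):
        super().__init__()
        self.backbone = backbone
        self.loss_fn = torch.nn.CrossEntropyLoss()
        metrics = MetricCollection({
            "accuracy": MulticlassAccuracy(num_classes=num_classes),
            "macro_f1": MulticlassF1Score(num_classes=num_classes, average="macro"),
        })
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")

    def shared_step(self, batch):
        x, y = batch
        return self.backbone(x), y
```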