
MultiTaskWrapper doesn't support being logged as a metric #2193

Closed
Ilykuleshov opened this issue Oct 30, 2023 · 1 comment · Fixed by #2213

🚀 Feature

Make MultiTaskWrapper (and possibly other wrappers) support logging from a pytorch_lightning.LightningModule. I expected it to just work with self.log_dict, but it doesn't, since it is subclassed from Metric. It would be great if it offered MetricCollection-like functionality.
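
A minimal sketch of the usage I expected to work; the module, metric choices, and task names below are illustrative and not taken from the torchmetrics docs:

    import pytorch_lightning as pl
    from torchmetrics.classification import BinaryAccuracy
    from torchmetrics.regression import MeanSquaredError
    from torchmetrics.wrappers import MultiTaskWrapper

    class ExampleModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # One metric per task, wrapped together
            self.train_metrics = MultiTaskWrapper(
                {"classification": BinaryAccuracy(), "regression": MeanSquaredError()}
            )

        def training_step(self, batch, batch_idx):
            # placeholder: per-task dicts, e.g. {"classification": ..., "regression": ...}
            preds, target = ...
            self.train_metrics(preds, target)
            # What I expected to work, by analogy with MetricCollection,
            # but it currently fails because MultiTaskWrapper is a plain Metric:
            self.log_dict(self.train_metrics, on_epoch=True)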

Motivation

Better compatibility with pytorch_lightning is always good; it would also make the APIs of the different torchmetrics tools more consistent.

Pitch

TorchMetrics's class metrics are great at standardizing metric calculation, which in most scenarios lets modules share boilerplate code such as logging metrics at each step (see Additional context for an example). This stems from torchmetrics' compatibility with pytorch_lightning. Most people using this library probably use it in conjunction with lightning, so increasing compatibility with lightning would definitely be an improvement.

Alternatives

Of course, one could use MultiTaskWrapper the way the official example outlines: call the metric, get back a plain Python dictionary, and log that dictionary (a sketch of this workaround follows the list below). But this approach has a few shortcomings:

  • The epoch-wise metrics could be wrong! E.g. if one uses a classification metric with average="macro", logging each batch separately and then averaging over the batches to get the epoch-level metric is simply incorrect. One would need to manually save all targets and predictions and compute the metric in the on_*_end hook.
  • This leaves room for confusion: I personally expected log_dict to work, and only after looking at the source code did I realize that it can't! The docs don't say a word about pytorch_lightning, so one expects it to work just like the other metrics do.
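
To make the comparison concrete, here is a rough sketch of that workaround inside a training_step; self.multitask_metrics, the task names, and shared_step are illustrative names, not part of the official example:

    def training_step(self, batch, batch_idx):
        preds_cls, preds_reg, target_cls, target_reg = self.shared_step(batch, batch_idx)

        # Calling the wrapper directly, as in the official example, returns a
        # plain dict of per-task metric values for *this batch only*.
        batch_values = self.multitask_metrics(
            {"classification": preds_cls, "regression": preds_reg},
            {"classification": target_cls, "regression": target_reg},
        )
        self.log_dict(batch_values)

        # Epoch-level values are now my responsibility: for e.g. a macro-averaged
        # AUROC, averaging the per-batch numbers is wrong, so I would have to call
        # self.multitask_metrics.compute() and .reset() in an on_*_end hook myself.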

Additional context

Let me give an example of moving shared boilerplate logic out of the individual modules. This could be the abstract parent class with the boilerplate code, which all child modules share:

    def training_step(
        self, batch, batch_idx: int
    ):
        preds, target = self.shared_step(batch, batch_idx)
        train_loss = self.loss(preds, target)
        self.train_metrics(preds, target)

        self.log("train_loss", train_loss, on_epoch=True)
        # Logging the Metric object itself lets Lightning handle epoch-level aggregation
        self.log_dict(self.train_metrics, on_epoch=True)  # type: ignore

        return train_loss

    def validation_step(
        self, batch, batch_idx: int
    ):
        preds, target = self.shared_step(batch, batch_idx)
        val_loss = self.loss(preds, target)
        self.val_metrics(preds, target)

        self.log("val_loss", val_loss)
        self.log_dict(self.val_metrics)  # type: ignore

    def test_step(
        self, batch, batch_idx: int
    ):
        preds, target = self.shared_step(batch, batch_idx)

        self.test_metrics(preds, target)
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True)  # type: ignore

And with that, all a child class has to do is specify its metrics and loss in __init__:

        def loss(probs, target):
            return nn.functional.nll_loss(
                torch.log(probs), target, ignore_index=pad_value
            )

        metrics = MetricCollection(
            {
                "AUROC": AUROC(
                    task="multiclass",
                    num_classes=num_types,
                    ignore_index=pad_value,
                    average="macro"
                ),
                "PR-AUC": AveragePrecision(
                    task="multiclass",
                    num_classes=num_types,
                    ignore_index=pad_value,
                    average="macro"
                ),
                "Accuracy": Accuracy(
                    task="multiclass",
                    num_classes=num_types,
                    ignore_index=pad_value,
                    average="macro"
                ),
                "F1Score": F1Score(
                    task="multiclass",
                    num_classes=num_types,
                    ignore_index=pad_value,
                    average="macro"
                ),
            }
        )

        super().__init__(
            loss=loss,
            metrics=metrics,
        )

and override shared_step to provide the model-specific prediction logic. This would even work for multi-task modules, if only MultiTaskWrapper supported this behaviour.
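
For illustration, once MultiTaskWrapper supports this, a multi-task child of the same boilerplate parent could look roughly like the sketch below; the class names, task names, loss, and batch keys are hypothetical:

    import torch
    from torchmetrics import Accuracy, MeanSquaredError
    from torchmetrics.wrappers import MultiTaskWrapper

    class MyMultiTaskModule(BaseModule):  # BaseModule: hypothetical name for the abstract parent above
        def __init__(self, num_types: int, pad_value: int):
            def loss(preds, target):
                # hypothetical combined loss over both tasks
                return torch.nn.functional.cross_entropy(
                    preds["next_type"], target["next_type"], ignore_index=pad_value
                ) + torch.nn.functional.mse_loss(preds["next_time"], target["next_time"])

            metrics = MultiTaskWrapper(
                {
                    "next_type": Accuracy(task="multiclass", num_classes=num_types, ignore_index=pad_value),
                    "next_time": MeanSquaredError(),
                }
            )
            super().__init__(loss=loss, metrics=metrics)

        def shared_step(self, batch, batch_idx):
            # Model-specific prediction logic: return per-task dicts of preds and targets,
            # which is the input format MultiTaskWrapper expects.
            type_logits, time_preds = self(batch)
            preds = {"next_type": type_logits, "next_time": time_preds}
            target = {"next_type": batch["types"], "next_time": batch["times"]}
            return preds, target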

Ilykuleshov added the enhancement (New feature or request) label on Oct 30, 2023
@github-actions

Hi! thanks for your contribution!, great first issue!
