diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 119853d5c5..2c8dfe6b85 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -11,12 +11,14 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch from torch.utils.data import DataLoader from monai.config import IgniteInfo, KeysCollection +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -25,7 +27,7 @@ from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys -from monai.utils.module import look_up_option +from monai.utils.module import look_up_option, pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -213,6 +215,10 @@ class SupervisedEvaluator(Evaluator): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ @@ -238,6 +244,8 @@ def __init__( decollate: bool = True, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -259,8 +267,16 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - + if compile: + if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + else: + warnings.warn( + "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" + ) self.network = network + self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: @@ -288,6 +304,24 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 + if self.compile: + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn( + "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -298,6 +332,19 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) else: engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 61b7028e11..f1513ea73b 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch @@ -18,6 +19,7 @@ from torch.utils.data import DataLoader from monai.config import IgniteInfo +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -25,6 +27,7 @@ from monai.utils import GanKeys, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys +from monai.utils.module import pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -125,7 +128,10 @@ class SupervisedTrainer(Trainer): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ def __init__( @@ -153,6 +159,8 @@ def __init__( optim_set_to_none: bool = False, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -174,8 +182,16 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - + if compile: + if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + else: + warnings.warn( + "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" + ) self.network = network + self.compile = compile self.optimizer = optimizer self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer @@ -207,6 +223,25 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 + if self.compile: + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn( + "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) + # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -231,6 +266,19 @@ def _compute_pred_loss(): engine.state.output[Keys.LOSS].backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) engine.optimizer.step() + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.MODEL_COMPLETED) return engine.state.output