Skip to content

Commit

Permalink
Add compile support in SupervisedTrainer and SupervisedEvaluator (
Browse files Browse the repository at this point in the history
#7375)

Fixes # .

### Description

Add `compile` support in `SupervisedTrainer` and `SupervisedEvaluator`.
Convert to `torch.Tensor` internally.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 19, 2024
1 parent 80be1c3 commit facf176
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 4 deletions.
51 changes: 49 additions & 2 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence

import torch
from torch.utils.data import DataLoader

from monai.config import IgniteInfo, KeysCollection
from monai.data import MetaTensor
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
Expand All @@ -25,7 +27,7 @@
from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import look_up_option
from monai.utils.module import look_up_option, pytorch_after

if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
Expand Down Expand Up @@ -213,6 +215,10 @@ class SupervisedEvaluator(Evaluator):
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
"""

Expand All @@ -238,6 +244,8 @@ def __init__(
decollate: bool = True,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
compile: bool = False,
compile_kwargs: dict | None = None,
) -> None:
super().__init__(
device=device,
Expand All @@ -259,8 +267,16 @@ def __init__(
to_kwargs=to_kwargs,
amp_kwargs=amp_kwargs,
)

if compile:
if pytorch_after(2, 1):
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
else:
warnings.warn(
"Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
)
self.network = network
self.compile = compile
self.inferer = SimpleInferer() if inferer is None else inferer

def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:
Expand Down Expand Up @@ -288,6 +304,24 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch
# FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026
if self.compile:
inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None
if isinstance(inputs, MetaTensor):
warnings.warn(
"Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass."
)
inputs, inputs_meta, inputs_applied_operations = (
inputs.as_tensor(),
inputs.meta,
inputs.applied_operations,
)
if isinstance(targets, MetaTensor):
targets, targets_meta, targets_applied_operations = (
targets.as_tensor(),
targets.meta,
targets.applied_operations,
)

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
Expand All @@ -298,6 +332,19 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
else:
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
# copy back meta info
if self.compile:
if inputs_meta is not None:
engine.state.output[Keys.IMAGE] = MetaTensor(
inputs, meta=inputs_meta, applied_operations=inputs_applied_operations
)
engine.state.output[Keys.PRED] = MetaTensor(
engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations
)
if targets_meta is not None:
engine.state.output[Keys.LABEL] = MetaTensor(
targets, meta=targets_meta, applied_operations=targets_applied_operations
)
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)

Expand Down
52 changes: 50 additions & 2 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,23 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence

import torch
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from monai.config import IgniteInfo
from monai.data import MetaTensor
from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
from monai.utils import GanKeys, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import pytorch_after

if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
Expand Down Expand Up @@ -125,7 +128,10 @@ class SupervisedTrainer(Trainer):
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
"""

def __init__(
Expand Down Expand Up @@ -153,6 +159,8 @@ def __init__(
optim_set_to_none: bool = False,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
compile: bool = False,
compile_kwargs: dict | None = None,
) -> None:
super().__init__(
device=device,
Expand All @@ -174,8 +182,16 @@ def __init__(
to_kwargs=to_kwargs,
amp_kwargs=amp_kwargs,
)

if compile:
if pytorch_after(2, 1):
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
else:
warnings.warn(
"Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
)
self.network = network
self.compile = compile
self.optimizer = optimizer
self.loss_function = loss_function
self.inferer = SimpleInferer() if inferer is None else inferer
Expand Down Expand Up @@ -207,6 +223,25 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch
# FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026
if self.compile:
inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None
if isinstance(inputs, MetaTensor):
warnings.warn(
"Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass."
)
inputs, inputs_meta, inputs_applied_operations = (
inputs.as_tensor(),
inputs.meta,
inputs.applied_operations,
)
if isinstance(targets, MetaTensor):
targets, targets_meta, targets_applied_operations = (
targets.as_tensor(),
targets.meta,
targets.applied_operations,
)

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

Expand All @@ -231,6 +266,19 @@ def _compute_pred_loss():
engine.state.output[Keys.LOSS].backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
engine.optimizer.step()
# copy back meta info
if self.compile:
if inputs_meta is not None:
engine.state.output[Keys.IMAGE] = MetaTensor(
inputs, meta=inputs_meta, applied_operations=inputs_applied_operations
)
engine.state.output[Keys.PRED] = MetaTensor(
engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations
)
if targets_meta is not None:
engine.state.output[Keys.LABEL] = MetaTensor(
targets, meta=targets_meta, applied_operations=targets_applied_operations
)
engine.fire_event(IterationEvents.MODEL_COMPLETED)

return engine.state.output
Expand Down

0 comments on commit facf176

Please sign in to comment.