Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compile support in SupervisedTrainer and SupervisedEvaluator #7375

Merged
merged 25 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torch.utils.data import DataLoader

from monai.config import IgniteInfo, KeysCollection
from monai.data import MetaTensor
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
Expand All @@ -25,7 +26,7 @@
from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import look_up_option
from monai.utils.module import look_up_option, pytorch_after

if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
Expand Down Expand Up @@ -213,6 +214,10 @@ class SupervisedEvaluator(Evaluator):
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
compile: whether to use compile, default is False.
If set to True, the inputs will be converted to `torch.Tensor` internally.
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.

"""

Expand All @@ -238,6 +243,8 @@ def __init__(
decollate: bool = True,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
compile: bool = False,
compile_kwargs: dict = {},
) -> None:
super().__init__(
device=device,
Expand All @@ -259,8 +266,12 @@ def __init__(
to_kwargs=to_kwargs,
amp_kwargs=amp_kwargs,
)

self.network = network
if compile:
assert pytorch_after(2, 1)
self.network = torch.compile(network, **compile_kwargs)
else:
self.network = network
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
self.compile = compile
self.inferer = SimpleInferer() if inferer is None else inferer

def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:
Expand Down Expand Up @@ -288,7 +299,9 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch

if self.compile:
inputs = torch.Tensor(inputs) if isinstance(inputs, MetaTensor) else inputs
targets = torch.Tensor(targets) if isinstance(targets, MetaTensor) else targets
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
# execute forward computation
Expand Down
20 changes: 17 additions & 3 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from torch.utils.data import DataLoader

from monai.config import IgniteInfo
from monai.data import MetaTensor
from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
from monai.utils import GanKeys, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import pytorch_after

if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
Expand Down Expand Up @@ -125,7 +127,10 @@ class SupervisedTrainer(Trainer):
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.

compile: whether to use compile, default is False.
If set to True, the inputs will be converted to `torch.Tensor` internally.
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
"""

def __init__(
Expand Down Expand Up @@ -153,6 +158,8 @@ def __init__(
optim_set_to_none: bool = False,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
compile: bool = False,
compile_kwargs: dict = {},
) -> None:
super().__init__(
device=device,
Expand All @@ -174,8 +181,12 @@ def __init__(
to_kwargs=to_kwargs,
amp_kwargs=amp_kwargs,
)

self.network = network
if compile:
assert pytorch_after(2, 1)
self.network = torch.compile(network, **compile_kwargs)
else:
self.network = network
self.compile = compile
self.optimizer = optimizer
self.loss_function = loss_function
self.inferer = SimpleInferer() if inferer is None else inferer
Expand Down Expand Up @@ -207,6 +218,9 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch
if self.compile:
inputs = torch.Tensor(inputs) if isinstance(inputs, MetaTensor) else inputs
targets = torch.Tensor(targets) if isinstance(targets, MetaTensor) else targets
# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

Expand Down
Loading