diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 406f686efe7324..2d5b680d587891 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -25,7 +25,7 @@
 import os
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 from weakref import proxy
 
 import torch
@@ -127,6 +127,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
+        compile_fn: Optional[Callable] = None,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -468,6 +469,8 @@ def __init__(
         self.should_stop = False
         self.state = TrainerState()
 
+        self.compile_fn = compile_fn
+
         # configure profiler
         setup._init_profiler(self, profiler)
 
@@ -956,6 +959,10 @@ def _run(
         # strategy will configure model and move it to the device
         self.strategy.setup(self)
 
+        # compile the model with the user-provided compile_fn once the strategy is configured
+        if self.compile_fn is not None:
+            self.strategy.model = self.compile_fn(self.strategy.model)
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             call._call_callback_hooks(self, "on_fit_start")
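
For context, a minimal usage sketch, assuming this patch is applied: any callable that maps a module to a module can be passed as `compile_fn`, e.g. `torch.compile` or a `functools.partial` around it. The `max_epochs` value and the commented `fit` call are illustrative, not part of the patch.

```python
# Hypothetical usage of the new compile_fn argument.
from functools import partial

import torch
from lightning.pytorch import Trainer

# compile with torch.compile's default settings
trainer = Trainer(max_epochs=3, compile_fn=torch.compile)

# or pin down the compilation mode
trainer = Trainer(
    max_epochs=3,
    compile_fn=partial(torch.compile, mode="reduce-overhead"),
)
# trainer.fit(model)  # the model is compiled inside Trainer._run()
```

Because the hook runs after `self.strategy.setup(self)`, the callable receives the strategy-wrapped, device-placed module rather than the raw `LightningModule`.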