diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 2d5b680d587891..bab0dce90beb34 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -25,7 +25,7 @@
 import os
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, Iterable, List, Optional, Union, Callable
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 from weakref import proxy
 
 import torch
@@ -127,7 +127,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
-        compile_fn: Optional[Callable] = None
+        compile_fn: Optional[Callable] = None,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -470,7 +470,7 @@ def __init__(
 
         self.state = TrainerState()
         self.compile_fn = compile_fn
-        
+
         # configure profiler
         setup._init_profiler(self, profiler)
@@ -962,7 +962,7 @@ def _run(
         # compile if compile_fn provided after configured strategy
         if self.compile_fn is not None:
            self.strategy.model = self.compile_fn(self.strategy.model)
-        
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             call._call_callback_hooks(self, "on_fit_start")
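
For reference, a minimal usage sketch of the API this diff introduces (not part of the patch itself; max_epochs is just an arbitrary example flag): compile_fn accepts any callable that takes the strategy's model and returns a replacement, and per the _run hunk above the trainer applies it after the strategy is configured but before on_fit_start fires. torch.compile is the natural candidate.

    import torch
    from lightning.pytorch import Trainer

    # compile_fn is any model -> model callable; the trainer applies it to
    # strategy.model once the strategy is configured, before on_fit_start.
    trainer = Trainer(
        max_epochs=3,  # arbitrary example flag, unrelated to this diff
        compile_fn=torch.compile,
    )

    # To forward extra torch.compile options, wrap it in a lambda:
    #     Trainer(compile_fn=lambda m: torch.compile(m, mode="reduce-overhead"))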