Skip to content

Commit

Permalink
Add compile_fn for Trainer
Browse files · Browse the repository at this point in the history
  • Loading branch information
mieshkiwrk authored Sep 10, 2024
1 parent bc3c9c5 commit fc3e0e0
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import os
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Union, Callable
from weakref import proxy

import torch
Expand Down Expand Up @@ -127,6 +127,7 @@ def __init__(
sync_batchnorm: bool = False,
reload_dataloaders_every_n_epochs: int = 0,
default_root_dir: Optional[_PATH] = None,
compile_fn: Optional[Callable] = None
) -> None:
r"""Customize every aspect of training via flags.
Expand Down Expand Up @@ -468,6 +469,8 @@ def __init__(
self.should_stop = False
self.state = TrainerState()

self.compile_fn = compile_fn

# configure profiler
setup._init_profiler(self, profiler)

Expand Down Expand Up @@ -956,6 +959,10 @@ def _run(
# strategy will configure model and move it to the device
self.strategy.setup(self)

# compile if compile_fn provided after configured strategy
if self.compile_fn is not None:
self.strategy.model = self.compile_fn(self.strategy.model)

# hook
if self.state.fn == TrainerFn.FITTING:
call._call_callback_hooks(self, "on_fit_start")
Expand Down

0 comments on commit fc3e0e0

Please sign in to comment.