forked from Clay-foundation/model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
80 lines (70 loc) · 2.26 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Command line interface to run the neural network model!
From the project root directory, do:
python trainer.py fit
References:
- https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html
- https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6
"""
from lightning.pytorch.callbacks import (
LearningRateMonitor, # noqa: F401
ModelCheckpoint,
)
from lightning.pytorch.cli import ArgsType, LightningCLI
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.plugins.io import AsyncCheckpointIO
from src.callbacks_wandb import ( # noqa: F401
LogIntermediatePredictions,
LogMAEReconstruction,
)
from src.datamodule import ClayDataModule, GeoTIFFDataPipeModule # noqa: F401
from src.model_clay import CLAYModule
from src.model_vit import ViTLitModule # noqa: F401
# %%
def _default_trainer_settings() -> dict:
    """
    Build a fresh dictionary of default ``Trainer`` keyword arguments.

    The callback, logger and plugin objects are constructed anew on every
    call, so nothing is instantiated at import time and separate
    ``cli_main`` invocations never share the same instances.
    """
    return {
        "accelerator": "auto",
        "devices": "auto",
        "strategy": "ddp",
        "precision": "bf16-mixed",
        "log_every_n_steps": 1,
        "max_epochs": 100,
        "accumulate_grad_batches": 5,
        "callbacks": [
            ModelCheckpoint(
                dirpath="checkpoints/",
                # The monitored metric name contains a "/", so the metric
                # name is spelled out in ``filename`` instead of being
                # auto-inserted by Lightning.
                auto_insert_metric_name=False,
                filename="mae_epoch-{epoch:02d}_val-loss-{val/loss:.2f}",
                monitor="val/loss",
                mode="min",
                save_last=True,
                save_top_k=2,
                save_weights_only=True,
                verbose=True,
            ),
            LearningRateMonitor(logging_interval="step"),
            LogIntermediatePredictions(),
        ],
        "logger": [WandbLogger(project="CLAY-v0", log_model=False)],
        "plugins": [AsyncCheckpointIO()],
    }


def cli_main(
    save_config_callback=None,
    seed_everything_default=42,
    trainer_defaults: "dict | None" = None,
    args: ArgsType = None,
):
    """
    Command-line interface to run CLAYModule with ClayDataModule.

    Parameters
    ----------
    save_config_callback
        Forwarded unchanged to ``LightningCLI(save_config_callback=...)``.
    seed_everything_default
        Forwarded unchanged to ``LightningCLI(seed_everything_default=...)``.
    trainer_defaults : dict or None
        Default keyword arguments for the Lightning ``Trainer``. When None
        (the default), a fresh set of defaults is built for this call:
        auto accelerator/devices, DDP strategy, bf16 mixed precision,
        100 max epochs, gradient accumulation over 5 batches, a
        ``ModelCheckpoint`` monitoring ``val/loss``, a step-wise
        ``LearningRateMonitor``, ``LogIntermediatePredictions``, a
        ``WandbLogger`` for project ``CLAY-v0`` and ``AsyncCheckpointIO``.
    args
        Forwarded unchanged to ``LightningCLI(args=...)``.

    Returns
    -------
    LightningCLI
        The CLI object created for this run.
    """
    # A dict literal in the signature would be evaluated once, at definition
    # time, creating the logger/callback objects on import and sharing them
    # across calls. Build the defaults lazily instead.
    if trainer_defaults is None:
        trainer_defaults = _default_trainer_settings()

    cli = LightningCLI(
        model_class=CLAYModule,
        datamodule_class=ClayDataModule,
        save_config_callback=save_config_callback,
        seed_everything_default=seed_everything_default,
        trainer_defaults=trainer_defaults,
        args=args,
    )
    return cli
# %%
if __name__ == "__main__":
    # Script entry point: build and run the CLI with its default settings,
    # then print a completion marker once cli_main() has returned.
    cli_main()
    print("Done!")