main.py
import os
import fire
from pytorch_lightning import Trainer
from util import init_exp_folder, Args
from util import constants as C
from data import dataset_split
from lightning import (get_task,
                       load_task,
                       get_ckpt_callback,
                       get_early_stop_callback,
                       get_logger)


def train(dataset_folder='./',
          save_dir="./sandbox",
          exp_name="MultitaskBaseline-ResNet101",
          model="ResNet101",
          products="naip-rgb",
          image_size=720,
          crop_size=None,
          augmentation="none",
          gpus=1,
          pretrained=True,
          accelerator=None,
          logger_type='test_tube',
          gradient_clip_val=0.5,
          max_epochs=20,
          batch_size=8,
          num_workers=2,
          lr=0.02,
          patience=10,
          stochastic_weight_avg=True,
          limit_train_batches=1.0,
          tb_path="./sandbox/tb",
          loss_fn="BCE",
          weights_summary=None,
          task_type="all"
          ):
"""
Run the training experiment.
Args:
save_dir: Path to save the checkpoints and logs
exp_name: Name of the experiment
model: Model name
gpus: int. (ie: 2 gpus)
OR list to specify which GPUs [0, 1] OR '0,1'
OR '-1' / -1 to use all available gpus
pretrained: Whether or not to use the pretrained model
accelerator: Distributed computing mode
logger_type: 'wandb' or 'test_tube'
gradient_clip_val: Clip value of gradient norm
limit_train_batches: Proportion of training data to use
max_epochs: Max number of epochs
batch_size: Batch size
patience: number of epochs with no improvement after
which training will be stopped.
stochastic_weight_avg: Whether to use stochastic weight averaging.
tb_path: Path to global tb folder
loss_fn: Loss function to use
weights_summary: Prints a summary of the weights when training begins.
Returns: None
"""
    args = Args(locals())

    # Determine the number of output classes from the requested task type.
    if task_type == 'all':
        args['num_classes'] = len(C.class_labels_list)
    else:
        if task_type not in C.class_labels_list:
            raise ValueError('Invalid task type.')
        args['num_classes'] = 1
    if products not in C.valid_products:
        raise ValueError('Invalid product type.')

    task = get_task(task_type, args)
    init_exp_folder(args)

    # Configure the Lightning Trainer with logging, checkpointing, and early stopping.
    trainer = Trainer(gpus=gpus,
                      accelerator=accelerator,
                      logger=get_logger(logger_type, save_dir, exp_name),
                      callbacks=[get_early_stop_callback(patience),
                                 get_ckpt_callback(save_dir, exp_name)],
                      weights_save_path=os.path.join(save_dir, exp_name),
                      gradient_clip_val=gradient_clip_val,
                      limit_train_batches=limit_train_batches,
                      weights_summary=weights_summary,
                      stochastic_weight_avg=stochastic_weight_avg,
                      max_epochs=max_epochs)
    trainer.fit(task)
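# A hedged usage sketch (example values only): fire.Fire() at the bottom of this
# file exposes `train` as a CLI subcommand, so training can be launched with e.g.
#   python main.py train --exp_name my-run --gpus 1 --max_epochs 20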
def validate(ckpt_path,
             gpus=1,
             **kwargs):
    """
    Run the validation experiment.

    Args:
        ckpt_path: Path for the experiment to load
        gpus: int. (i.e. 2 gpus)
            OR list to specify which GPUs [0, 1] OR '0,1'
            OR '-1' / -1 to use all available gpus
    Returns: None
    """
    task = load_task(ckpt_path, **kwargs)
    trainer = Trainer(gpus=gpus)
    trainer.validate(task)
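# A hedged usage sketch (the checkpoint path is a placeholder): validation on a
# saved experiment can be run with e.g.
#   python main.py validate --ckpt_path path/to/checkpoint.ckpt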
def test(ckpt_path,
         gpus=1,
         **kwargs):
    """
    Run the testing experiment.

    Args:
        ckpt_path: Path for the experiment to load
        gpus: int. (i.e. 2 gpus)
            OR list to specify which GPUs [0, 1] OR '0,1'
            OR '-1' / -1 to use all available gpus
    Returns: None
    """
    task = load_task(ckpt_path, **kwargs)
    trainer = Trainer(gpus=gpus)
    trainer.test(task)
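# A hedged usage sketch (the checkpoint path is a placeholder): testing mirrors
# validation, e.g.
#   python main.py test --ckpt_path path/to/checkpoint.ckpt --gpus 1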
if __name__ == "__main__":
fire.Fire()