# train.py: trains an FCOS detector on Pascal VOC.
import math
import time

import torch
from torch.utils.tensorboard import SummaryWriter

from dataloader.VOC_dataset import VOCDataset
from model.config import DefaultConfig as cfg
from model.fcos import FCOSDetector
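# Datasets: the train split uses the configured image scale; the val split is
# resized to a fixed [512, 800].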
train_dataset = VOCDataset(cfg.dataset_dir, resize_size=cfg.image_scale, split='train')
val_dataset = VOCDataset(cfg.dataset_dir, resize_size=[512, 800], split='val')
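# Build the detector in training mode on the GPU, optionally resuming from a
# saved state dict.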
model = FCOSDetector(mode="training").cuda()
if cfg.resume_path:
    model.load_state_dict(torch.load(cfg.resume_path))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
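# Hyper-parameters and schedule lengths come from DefaultConfig (the warm-up
# ratio attribute is spelled WARMPUP_STEPS_RATIO there, so that spelling is
# kept below).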
BATCH_SIZE = cfg.BATCH_SIZE
EPOCHS = cfg.EPOCHS
WARMPUP_STEPS_RATIO = cfg.WARMPUP_STEPS_RATIO

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                           shuffle=True, collate_fn=train_dataset.collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, collate_fn=val_dataset.collate_fn)

steps_per_epoch = len(train_dataset) // BATCH_SIZE
TOTAL_STEPS = steps_per_epoch * EPOCHS
WARMPUP_STEPS = TOTAL_STEPS * WARMPUP_STEPS_RATIO

GLOBAL_STEPS = cfg.GLOBAL_STEPS
LR_INIT = cfg.LR_INIT
LR_END = cfg.LR_END
writer = SummaryWriter(log_dir="./logs")
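# Learning-rate schedule: linear warm-up for the first WARMPUP_STEPS steps,
# then cosine annealing from LR_INIT down to LR_END over the remaining steps.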
def lr_func():
    if GLOBAL_STEPS < WARMPUP_STEPS:
        lr = GLOBAL_STEPS / WARMPUP_STEPS * LR_INIT
    else:
        lr = LR_END + 0.5 * (LR_INIT - LR_END) * (
            1 + math.cos((GLOBAL_STEPS - WARMPUP_STEPS) / (TOTAL_STEPS - WARMPUP_STEPS) * math.pi)
        )
    return float(lr)
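# Main training loop: one optimizer step per batch, with a per-step LR update,
# console logging, and TensorBoard scalars; a checkpoint is written every ten
# epochs.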
model.train()

for epoch in range(EPOCHS):
    for epoch_step, data in enumerate(train_loader):
        batch_imgs, batch_boxes, batch_classes = data
        batch_imgs = batch_imgs.cuda()
        batch_boxes = batch_boxes.cuda()
        batch_classes = batch_classes.cuda()

        # Set this step's learning rate on every parameter group.
        lr = lr_func()
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        start_time = time.time()

        optimizer.zero_grad()
        # The model returns (cls_loss, cnt_loss, reg_loss, total_loss);
        # only the total loss is backpropagated.
        losses = model([batch_imgs, batch_boxes, batch_classes])
        loss = losses[-1]
        loss.backward()
        optimizer.step()

        end_time = time.time()
        cost_time = int((end_time - start_time) * 1000)
        print("global_steps:%d epoch:%d steps:%d/%d cls_loss:%.4f cnt_loss:%.4f reg_loss:%.4f cost_time:%dms lr=%.4e" %
              (GLOBAL_STEPS, epoch + 1, epoch_step + 1, steps_per_epoch,
               losses[0], losses[1], losses[2], cost_time, lr))

        writer.add_scalar("loss/cls_loss", losses[0], global_step=GLOBAL_STEPS)
        writer.add_scalar("loss/cnt_loss", losses[1], global_step=GLOBAL_STEPS)
        writer.add_scalar("loss/reg_loss", losses[2], global_step=GLOBAL_STEPS)
        writer.add_scalar("lr", lr, global_step=GLOBAL_STEPS)

        GLOBAL_STEPS += 1

    # Checkpoint every 10 epochs, tagging the file with the last total loss.
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), "./FCOSMASK_epoch%d_loss%.4f.pth" % (epoch + 1, loss.item()))
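# Flush pending TensorBoard events before exit.
writer.close()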