changed hardcoded batch size to per-batch #5

Open · wants to merge 2 commits into master
89 changes: 39 additions & 50 deletions fcrn.py
@@ -46,9 +46,8 @@ def forward(self, x):
 
 class UpProject(nn.Module):
 
-    def __init__(self, in_channels, out_channels, batch_size):
+    def __init__(self, in_channels, out_channels):
         super(UpProject, self).__init__()
-        self.batch_size = batch_size
 
         self.conv1_1 = nn.Conv2d(in_channels, out_channels, 3)
         self.conv1_2 = nn.Conv2d(in_channels, out_channels, (2, 3))
@@ -70,41 +69,42 @@ def __init__(self, in_channels, out_channels, batch_size):
         self.bn2 = nn.BatchNorm2d(out_channels)
 
     def forward(self, x):
+        batch_size = x.shape[0]
         # b, 10, 8, 1024
         out1_1 = self.conv1_1(nn.functional.pad(x, (1, 1, 1, 1)))
-        out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 0, 1)))#right interleaving padding
-        #out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 1, 0)))#author's interleaving padding on GitHub
-        out1_3 = self.conv1_3(nn.functional.pad(x, (0, 1, 1, 1)))#right interleaving padding
-        #out1_3 = self.conv1_3(nn.functional.pad(x, (1, 0, 1, 1)))#author's interleaving padding on GitHub
-        out1_4 = self.conv1_4(nn.functional.pad(x, (0, 1, 0, 1)))#right interleaving padding
-        #out1_4 = self.conv1_4(nn.functional.pad(x, (1, 0, 1, 0)))#author's interleaving padding on GitHub
+        out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 0, 1)))  # right interleaving padding
+        # out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 1, 0)))  # author's interleaving padding on GitHub
+        out1_3 = self.conv1_3(nn.functional.pad(x, (0, 1, 1, 1)))  # right interleaving padding
+        # out1_3 = self.conv1_3(nn.functional.pad(x, (1, 0, 1, 1)))  # author's interleaving padding on GitHub
+        out1_4 = self.conv1_4(nn.functional.pad(x, (0, 1, 0, 1)))  # right interleaving padding
+        # out1_4 = self.conv1_4(nn.functional.pad(x, (1, 0, 1, 0)))  # author's interleaving padding on GitHub
 
         out2_1 = self.conv2_1(nn.functional.pad(x, (1, 1, 1, 1)))
-        out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 0, 1)))#right interleaving padding
-        #out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 1, 0)))#author's interleaving padding on GitHub
-        out2_3 = self.conv2_3(nn.functional.pad(x, (0, 1, 1, 1)))#right interleaving padding
-        #out2_3 = self.conv2_3(nn.functional.pad(x, (1, 0, 1, 1)))#author's interleaving padding on GitHub
-        out2_4 = self.conv2_4(nn.functional.pad(x, (0, 1, 0, 1)))#right interleaving padding
-        #out2_4 = self.conv2_4(nn.functional.pad(x, (1, 0, 1, 0)))#author's interleaving padding on GitHub
+        out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 0, 1)))  # right interleaving padding
+        # out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 1, 0)))  # author's interleaving padding on GitHub
+        out2_3 = self.conv2_3(nn.functional.pad(x, (0, 1, 1, 1)))  # right interleaving padding
+        # out2_3 = self.conv2_3(nn.functional.pad(x, (1, 0, 1, 1)))  # author's interleaving padding on GitHub
+        out2_4 = self.conv2_4(nn.functional.pad(x, (0, 1, 0, 1)))  # right interleaving padding
+        # out2_4 = self.conv2_4(nn.functional.pad(x, (1, 0, 1, 0)))  # author's interleaving padding on GitHub
 
         height = out1_1.size()[2]
         width = out1_1.size()[3]
 
         out1_1_2 = torch.stack((out1_1, out1_2), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(
-            self.batch_size, -1, height, width * 2)
+            batch_size, -1, height, width * 2)
         out1_3_4 = torch.stack((out1_3, out1_4), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(
-            self.batch_size, -1, height, width * 2)
+            batch_size, -1, height, width * 2)
 
         out1_1234 = torch.stack((out1_1_2, out1_3_4), dim=-3).permute(0, 1, 3, 2, 4).contiguous().view(
-            self.batch_size, -1, height * 2, width * 2)
+            batch_size, -1, height * 2, width * 2)
 
         out2_1_2 = torch.stack((out2_1, out2_2), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(
-            self.batch_size, -1, height, width * 2)
+            batch_size, -1, height, width * 2)
         out2_3_4 = torch.stack((out2_3, out2_4), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(
-            self.batch_size, -1, height, width * 2)
+            batch_size, -1, height, width * 2)
 
         out2_1234 = torch.stack((out2_1_2, out2_3_4), dim=-3).permute(0, 1, 3, 2, 4).contiguous().view(
-            self.batch_size, -1, height * 2, width * 2)
+            batch_size, -1, height * 2, width * 2)
 
         out1 = self.bn1_1(out1_1234)
         out1 = self.relu(out1)
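The stack/permute/view chains above implement the fast up-projection interleave: four shifted convolutions (3×3, 2×3, 3×2, 2×2 in the paper's fast up-convolution) are woven into one map at twice the resolution, and this PR reads the batch dimension from x.shape[0] at forward time instead of storing it at construction. A minimal standalone sketch of that trick (the interleave helper is hypothetical, not part of the diff):

    import torch

    def interleave(a, b, c, d):
        # a, b, c, d: (batch, ch, h, w) -> (batch, ch, 2h, 2w)
        batch_size, ch, h, w = a.shape  # per-batch size, as in this PR
        ab = torch.stack((a, b), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(batch_size, ch, h, w * 2)
        cd = torch.stack((c, d), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(batch_size, ch, h, w * 2)
        return torch.stack((ab, cd), dim=-3).permute(0, 1, 3, 2, 4).contiguous().view(batch_size, ch, h * 2, w * 2)

    maps = [torch.full((5, 1, 2, 2), float(i)) for i in range(4)]  # any batch size works
    print(interleave(*maps)[0, 0])  # rows alternate 0,1,0,1 / 2,3,2,3: a 2x2 [[0, 1], [2, 3]] tiling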
@@ -121,35 +121,34 @@ def forward(self, x):

 class FCRN(nn.Module):
 
-    def __init__(self, batch_size):
+    def __init__(self):
         super(FCRN, self).__init__()
         self.inplanes = 64
-        self.batch_size = batch_size
         # b, 304, 228, 3
         # ResNet without avgpool & fc
-        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)# b, 152, 114, 64
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # b, 152, 114, 64
         self.bn1 = nn.BatchNorm2d(64)
         self.relu = nn.ReLU(inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# b, 76, 57, 64
-        self.layer1 = self._make_layer(Bottleneck, 64, 3) #b, 76, 57, 256
-        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)# b, 38, 29, 512
-        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)# b, 19, 15, 1024
-        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)# b, 10, 8, 2048
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # b, 76, 57, 64
+        self.layer1 = self._make_layer(Bottleneck, 64, 3)  # b, 76, 57, 256
+        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)  # b, 38, 29, 512
+        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)  # b, 19, 15, 1024
+        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)  # b, 10, 8, 2048
 
         # Up-Conv layers
-        self.conv2 = nn.Conv2d(2048, 1024, kernel_size=1, bias=False)# b, 10, 8, 1024
+        self.conv2 = nn.Conv2d(2048, 1024, kernel_size=1, bias=False)  # b, 10, 8, 1024
         self.bn2 = nn.BatchNorm2d(1024)
 
-        self.up1 = self._make_upproj_layer(UpProject, 1024, 512, self.batch_size)
-        self.up2 = self._make_upproj_layer(UpProject, 512, 256, self.batch_size)
-        self.up3 = self._make_upproj_layer(UpProject, 256, 128, self.batch_size)
-        self.up4 = self._make_upproj_layer(UpProject, 128, 64, self.batch_size)
+        self.up1 = self._make_upproj_layer(UpProject, 1024, 512)
+        self.up2 = self._make_upproj_layer(UpProject, 512, 256)
+        self.up3 = self._make_upproj_layer(UpProject, 256, 128)
+        self.up4 = self._make_upproj_layer(UpProject, 128, 64)
 
         self.drop = nn.Dropout2d()
 
         self.conv3 = nn.Conv2d(64, 1, 3, padding=1)
 
-        self.upsample = nn.Upsample((228, 304), mode='bilinear')
+        self.upsample = nn.Upsample((228, 304), mode='bilinear', align_corners=False)
 
         # initialize
         initialize = False
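One incidental change here: nn.Upsample now passes align_corners=False explicitly, which matches PyTorch's default behavior and silences the warning newer PyTorch versions emit when the flag is left unset. A quick shape check (a sketch, assuming only torch):

    import torch

    up = torch.nn.Upsample((228, 304), mode='bilinear', align_corners=False)
    print(up(torch.zeros(1, 1, 10, 8)).shape)  # torch.Size([1, 1, 228, 304])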
@@ -179,8 +178,8 @@ def _make_layer(self, block, planes, blocks, stride=1):

         return nn.Sequential(*layers)
 
-    def _make_upproj_layer(self, block, in_channels, out_channels, batch_size):
-        return block(in_channels, out_channels, batch_size)
+    def _make_upproj_layer(self, block, in_channels, out_channels):
+        return block(in_channels, out_channels)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -210,23 +209,13 @@ def forward(self, x):

         return x
 
 
+from torchsummary import summary
+
 # Test the network model
 if __name__ == '__main__':
-    batch_size = 1
-    net = FCRN(batch_size).cuda()
-    x = torch.zeros(batch_size, 3,304,228).cuda()
+    batch_size = 3
+    net = FCRN().cuda()
+    x = torch.zeros(batch_size, 3, 304, 228).cuda()
     print(net(x).size())
+    summary(net, (3, 304, 228))
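Because the batch size is now derived inside forward(), one FCRN instance handles any batch size, including a final partial batch. A usage sketch (assumes a CUDA device and the FCRN class above):

    net = FCRN().cuda()
    for bs in (1, 3, 7):  # mixed batch sizes, one model
        x = torch.zeros(bs, 3, 304, 228).cuda()
        assert net(x).shape == (bs, 1, 228, 304)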
1 change: 0 additions & 1 deletion model/readme.txt

This file was deleted.

16 changes: 8 additions & 8 deletions train.py
@@ -4,7 +4,7 @@
 from fcrn import FCRN
 from torch.autograd import Variable
 from weights import load_weights
-from utils import load_split, loss_mse, loss_huber
+from utils import load_split, loss_mse, loss_berhu
 import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plot
@@ -25,16 +25,16 @@ def main():
     # 1. Load data
     train_lists, val_lists, test_lists = load_split()
     print("Loading data...")
-    train_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, train_lists),
-                                               batch_size=batch_size, shuffle=False, drop_last=True)
-    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists),
-                                             batch_size=batch_size, shuffle=True, drop_last=True)
+    # train_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, train_lists),
+    #                                            batch_size=batch_size, shuffle=False, drop_last=True)
+    # val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists),
+    #                                          batch_size=batch_size, shuffle=True, drop_last=True)
     test_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, test_lists),
                                               batch_size=batch_size, shuffle=True, drop_last=True)
-    print(train_loader)
+    # print(train_loader)
     # 2. Load model
     print("Loading model...")
-    model = FCRN(batch_size)
+    model = FCRN()
     model.load_state_dict(load_weights(model, weights_file, dtype))  # load the official weights, converted from TensorFlow
     # load a saved training checkpoint
     resume_from_file = False
@@ -56,7 +56,7 @@ def main():
     # custom MSE
     # loss_fn = loss_mse()
     # loss from the paper: the reverse Huber (berHu)
-    loss_fn = loss_huber()
+    loss_fn = loss_berhu()
    print("loss_fn set...")
 
     # 4. Optim
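Since FCRN no longer bakes in a batch size, drop_last=True should no longer be required for correctness: the final, smaller batch of an epoch can now pass through the model. A possible follow-up (a sketch, not part of this PR):

    test_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, test_lists),
                                              batch_size=batch_size, shuffle=True,
                                              drop_last=False)  # partial final batch is now handled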
24 changes: 15 additions & 9 deletions utils.py
@@ -3,42 +3,47 @@
 import torch.nn as nn
 from torch.autograd import Variable
 import numpy as np
+
+
 # custom loss function
-class loss_huber(nn.Module):
+class loss_berhu(nn.Module):
     def __init__(self):
-        super(loss_huber,self).__init__()
+        super(loss_berhu, self).__init__()
 
     def forward(self, pred, truth):
-        c = pred.shape[1] # channels
-        h = pred.shape[2] # height
-        w = pred.shape[3] # width
+        c = pred.shape[1]  # channels
+        h = pred.shape[2]  # height
+        w = pred.shape[3]  # width
         pred = pred.view(-1, c * h * w)
         truth = truth.view(-1, c * h * w)
         # threshold computed over all pixels in the current batch
         t = 0.2 * torch.max(torch.abs(pred - truth))
         # compute the L1 norm
         l1 = torch.mean(torch.mean(torch.abs(pred - truth), 1), 0)
         # compute the L2 term from the paper
-        l2 = torch.mean(torch.mean(((pred - truth)**2 + t**2) / t / 2, 1), 0)
+        l2 = torch.mean(torch.mean(((pred - truth) ** 2 + t ** 2) / t / 2, 1), 0)
 
         if l1 > t:
             return l2
         else:
             return l1
 
+
 class loss_mse(nn.Module):
     def __init__(self):
         super(loss_mse, self).__init__()
 
     def forward(self, pred, truth):
         c = pred.shape[1]
         h = pred.shape[2]
         w = pred.shape[3]
         pred = pred.view(-1, c * h * w)
         truth = truth.view(-1, c * h * w)
-        return torch.mean(torch.mean((pred - truth), 1)**2, 0)
+        return torch.mean(torch.mean((pred - truth), 1) ** 2, 0)
 
+
 if __name__ == '__main__':
-    loss = loss_huber()
+    loss = loss_berhu()
     x = torch.zeros(2, 1, 2, 2)
     y = torch.ones(2, 1, 2, 2)
     c = x.shape[1]
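A note on loss_berhu: the class above compares the batch-mean L1 against the threshold and returns either the mean L1 or the mean quadratic term, whereas the berHu of Laina et al. is usually applied element-wise. An element-wise sketch under that reading (the berhu helper is hypothetical, not part of the diff):

    import torch

    def berhu(pred, truth):
        diff = torch.abs(pred - truth)
        t = (0.2 * torch.max(diff)).clamp(min=1e-6)  # threshold from the current batch; guard against t == 0
        l2 = (diff ** 2 + t ** 2) / (2 * t)          # quadratic branch above the threshold
        return torch.mean(torch.where(diff <= t, diff, l2))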
@@ -79,6 +84,7 @@ def load_split():

     return train_lists, val_lists, test_lists
 
+
 # Test the network
 def validate(model, val_loader, loss_fn, dtype):
     # validate
@@ -112,4 +118,4 @@ def validate(model, val_loader, loss_fn, dtype):
         num_samples += 1
 
     err = float(loss_local) / num_samples
-    print('val_error: %f' % err)
\ No newline at end of file
+    print('val_error: %f' % err)