Skip to content

Commit

Permalink
Upload ENet-BDD100K-Torch
Browse files Browse the repository at this point in the history
  • Loading branch information
VernamCU committed Aug 15, 2019
1 parent 73d5cd5 commit 0fdf9cb
Show file tree
Hide file tree
Showing 31 changed files with 82,431 additions and 2 deletions.
30 changes: 30 additions & 0 deletions ENet-BDD100K-Torch/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
BSD License

For fb.resnet.torch software

Copyright (c) 2016, Facebook, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name Facebook nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43 changes: 43 additions & 0 deletions ENet-BDD100K-Torch/ParallelCriterion2.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
local ParallelCriterion2, parent = torch.class('nn.ParallelCriterion2', 'nn.Criterion')

function ParallelCriterion2:__init(repeatTarget)
parent.__init(self)
self.criterions = {}
self.weights = {}
self.gradInput = {}
self.repeatTarget = repeatTarget
end

function ParallelCriterion2:add(criterion, weight)
assert(criterion, 'no criterion provided')
weight = weight or 1
table.insert(self.criterions, criterion)
table.insert(self.weights, weight)
return self
end

function ParallelCriterion2:updateOutput(input, target)
self.output = 0
local output = {}
for i,criterion in ipairs(self.criterions) do
local target = self.repeatTarget and target or target[i]
self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target)
table.insert(output, self.weights[i]*criterion:updateOutput(input[i],target))
end
return self.output, output
end

function ParallelCriterion2:updateGradInput(input, target)
self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
nn.utils.recursiveFill(self.gradInput, 0)
for i,criterion in ipairs(self.criterions) do
local target = self.repeatTarget and target or target[i]
nn.utils.recursiveAdd(self.gradInput[i], self.weights[i], criterion:updateGradInput(input[i], target))
end
return self.gradInput
end

function ParallelCriterion2:type(type, tensorCache)
self.gradInput = {}
return parent.type(self, type, tensorCache)
end
21 changes: 21 additions & 0 deletions ENet-BDD100K-Torch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
### Before Start

Please follow [list](./list) to put the BDD100K dataset (train, val, test) in the desired folder. The labels generated by our method are provided here ([train](https://drive.google.com/open?id=1wjoOQT6OJlLETz0ZYThBWjSt1Tgzn6_j) and [val](https://drive.google.com/open?id=1WaUjdgI4CMXkYFfi8Lz2rHiYYWLU5hav)). We'll call the directory that you cloned ENet-BDD100K-Torch as `$ENet_BDD100K_ROOT`. Note that this model only uses ENet as backbone, and if you use ENet-SAD, the performance will be better.

### Testing
1. Run test script
```Shell
cd $ENet_BDD100K_ROOT
sh ./experiments/test_ENet.sh
```
By now, you should be able to reproduce the result (Accuracy: 0.3656, mIoU: 16.02).

### Training
1. Training ENet model
```Shell
cd $ENet_BDD100K_ROOT
sh ./experiments/train_ENet.sh
```
The training process should start and trained models would be saved in `experiments/models/ENet-new/` by default.
Then you can test the trained model following the Testing steps above. If your model position or name is changed, remember to set them to yours accordingly.

71 changes: 71 additions & 0 deletions ENet-BDD100K-Torch/checkpoints.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
local checkpoint = {}

local function deepCopy(tbl)
-- creates a copy of a network with new modules and the same tensors
local copy = {}
for k, v in pairs(tbl) do
if type(v) == 'table' then
copy[k] = deepCopy(v)
else
copy[k] = v
end
end
if torch.typename(tbl) then
torch.setmetatable(copy, torch.typename(tbl))
end
return copy
end

function checkpoint.latest(opt)
if opt.resume == 'none' then
return nil
end

local latestPath = paths.concat(opt.resume, 'latest.t7')
if not paths.filep(latestPath) then
return nil
end

print('=> Loading checkpoint ' .. latestPath)
local latest = torch.load(latestPath)
local optimState = torch.load(paths.concat(opt.resume, latest.optimFile))

return latest, optimState
end

function checkpoint.save(epoch, model, optimState, isBestModel, opt, iter, bestLoss)
-- don't save the DataParallelTable for easier loading on other machines
if torch.type(model) == 'nn.DataParallelTable' then
model = model:get(1)
end

-- create a clean copy on the CPU without modifying the original network
model = deepCopy(model):float():clearState()

local modelFile = 'model_new.t7'
local optimFile = 'optimState_new.t7'

torch.save(paths.concat(opt.save, modelFile), model)
torch.save(paths.concat(opt.save, optimFile), optimState)
torch.save(paths.concat(opt.save, 'latest.t7'), {
iter = iter,
epoch = epoch,
bestLoss = bestLoss,
modelFile = modelFile,
optimFile = optimFile,
})

if isBestModel then
torch.save(paths.concat(opt.save, 'model_best.t7'), model)
end
end

return checkpoint
169 changes: 169 additions & 0 deletions ENet-BDD100K-Torch/dataloader.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Multi-threaded data loader
--

local datasets = require 'datasets/init'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local M = {}
local DataLoader = torch.class('resnet.DataLoader', M)

function DataLoader.create(opt)
-- The train and val loader
local loaders = {}
local data
if opt.dataset == 'lane' then
data = {'train', 'val'}
elseif opt.dataset == 'laneTest' then
data = {'val'}
else
cmd:error('unknown dataset: ' .. opt.dataset)
end
for i, split in ipairs(data) do
local dataset = datasets.create(opt, split)
print("data created")
loaders[i] = M.DataLoader(dataset, opt, split)
print("data loaded")
end

return table.unpack(loaders)
end

function DataLoader:__init(dataset, opt, split)
local manualSeed = opt.manualSeed
local function init()
require('datasets/' .. opt.dataset)
end
local function main(idx)
if manualSeed ~= 0 then
torch.manualSeed(manualSeed + idx)
end
torch.setnumthreads(1)
_G.dataset = dataset
_G.preprocess = dataset:preprocess()
_G.preprocess_aug = dataset:preprocess_aug()
return dataset:size()
end

local threads, sizes = Threads(opt.nThreads, init, main)
-- self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
self.nCrops = 1
self.threads = threads
self.__size = sizes[1][1]
self.batchSize = math.floor(opt.batchSize / self.nCrops)
self.split = split
self.dataset = opt.dataset
end

function DataLoader:size()
return math.ceil(self.__size / self.batchSize)
end

function DataLoader:run()
local threads = self.threads
local size, batchSize = self.__size, self.batchSize
local dataset = self.dataset
--if self.split == 'val' then
--batchSize = torch.round(batchSize / 2)
--end
local perm
if self.split == 'val' then
perm = torch.Tensor(size)
for i = 1, size do
perm[i] = i
end
else
perm = torch.randperm(size)
end

local idx, sample = 1, nil
local function enqueue()
while idx <= size and threads:acceptsjob() do
local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
threads:addjob(
function(indices, nCrops)
local sz = indices:size(1)
local batch, segLabels, exists, imgpaths
for i, idx in ipairs(indices:totable()) do
local sample = _G.dataset:get(idx)
local input, segLabel, exist
if dataset=='laneTest' then
input = _G.preprocess(sample.input)
elseif dataset=='lane' then
input, segLabel, exist = _G.preprocess_aug(sample.input, sample.segLabel, sample.exist)
--print(segLabel:size())
segLabel:resize(segLabel:size(2),segLabel:size(3))
else
cmd:error('unknown dataset: ' .. dataset)
end
if not batch then
local imageSize = input:size():totable()
local pathSize = sample.imgpath:size():totable()
batch = torch.FloatTensor(sz, table.unpack(imageSize))
imgpaths = torch.CharTensor(sz, table.unpack(pathSize))
if dataset=='lane' then
local labelSize = segLabel:size():totable()
local existSize = exist:size():totable()
segLabels = torch.FloatTensor(sz, table.unpack(labelSize))
exists = torch.FloatTensor(sz, table.unpack(existSize))
end
end
batch[i]:copy(input)
imgpaths[i]:copy(sample.imgpath)
if dataset=='lane' then
segLabels[i]:copy(segLabel)
exists[i]:copy(exist)
end
end
local targets
if dataset=='laneTest' then
targets = nil
elseif dataset=='lane' then
targets = {segLabels, exists}
else
cmd:error('unknown dataset: ' .. dataset)
end
collectgarbage(); collectgarbage()

return {
input = batch,
target = targets,
imgpath = imgpaths, -- used in test
}
end,
function(_sample_)
sample = _sample_
end,
indices,
self.nCrops
)
idx = idx + batchSize
end
end

local n = 0
local function loop()
enqueue()
if not threads:hasjob() then
return nil
end
threads:dojob()
if threads:haserror() then
threads:synchronize()
end
enqueue()
n = n + 1
return n, sample
end
return loop
end

return M.DataLoader
Loading

0 comments on commit 0fdf9cb

Please sign in to comment.