Forked from snarb/Codes-for-Lane-Detection.
Showing 31 changed files with 82,431 additions and 2 deletions.
@@ -0,0 +1,30 @@
BSD License

For fb.resnet.torch software

Copyright (c) 2016, Facebook, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

 * Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

 * Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

 * Neither the name Facebook nor the names of its contributors may be used to
   endorse or promote products derived from this software without specific
   prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,43 @@
```lua
local ParallelCriterion2, parent = torch.class('nn.ParallelCriterion2', 'nn.Criterion')

function ParallelCriterion2:__init(repeatTarget)
   parent.__init(self)
   self.criterions = {}
   self.weights = {}
   self.gradInput = {}
   self.repeatTarget = repeatTarget
end

function ParallelCriterion2:add(criterion, weight)
   assert(criterion, 'no criterion provided')
   weight = weight or 1
   table.insert(self.criterions, criterion)
   table.insert(self.weights, weight)
   return self
end

function ParallelCriterion2:updateOutput(input, target)
   self.output = 0
   local output = {}
   for i, criterion in ipairs(self.criterions) do
      local target = self.repeatTarget and target or target[i]
      -- evaluate each criterion once and reuse the weighted loss,
      -- instead of calling updateOutput twice per criterion
      local weightedLoss = self.weights[i] * criterion:updateOutput(input[i], target)
      self.output = self.output + weightedLoss
      table.insert(output, weightedLoss)
   end
   return self.output, output
end

function ParallelCriterion2:updateGradInput(input, target)
   self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
   nn.utils.recursiveFill(self.gradInput, 0)
   for i, criterion in ipairs(self.criterions) do
      local target = self.repeatTarget and target or target[i]
      nn.utils.recursiveAdd(self.gradInput[i], self.weights[i], criterion:updateGradInput(input[i], target))
   end
   return self.gradInput
end

function ParallelCriterion2:type(type, tensorCache)
   self.gradInput = {}
   return parent.type(self, type, tensorCache)
end
```
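Below is a minimal, hypothetical usage sketch (not part of this commit) showing how `ParallelCriterion2` can combine two weighted losses; the file name `ParallelCriterion2.lua` and the choice of `nn.MSECriterion`/`nn.BCECriterion` are assumptions made only for illustration.

```lua
-- Hypothetical usage sketch; file name and criterion choices are assumptions.
require 'nn'
dofile('ParallelCriterion2.lua')  -- assumed path to the module above

-- Two weighted sub-criterions, e.g. a main loss and an auxiliary loss.
local criterion = nn.ParallelCriterion2()
   :add(nn.MSECriterion(), 1.0)
   :add(nn.BCECriterion(), 0.1)

-- Illustrative inputs: a paired table of predictions and a paired table of targets.
local input  = {torch.randn(4, 10), torch.rand(4, 4):clamp(0.01, 0.99)}
local target = {torch.randn(4, 10), torch.Tensor(4, 4):bernoulli()}

local total, perCriterion = criterion:updateOutput(input, target)
print(total, perCriterion[1], perCriterion[2])
local gradInput = criterion:updateGradInput(input, target)
```

Unlike the stock `nn.ParallelCriterion`, `updateOutput` here also returns the table of per-criterion weighted losses, which is convenient for logging each term separately.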
@@ -0,0 +1,21 @@
### Before Start

Please follow [list](./list) to place the BDD100K dataset (train, val, test) in the expected folders. The labels generated by our method are provided here ([train](https://drive.google.com/open?id=1wjoOQT6OJlLETz0ZYThBWjSt1Tgzn6_j) and [val](https://drive.google.com/open?id=1WaUjdgI4CMXkYFfi8Lz2rHiYYWLU5hav)). We will refer to the directory into which you cloned ENet-BDD100K-Torch as `$ENet_BDD100K_ROOT`. Note that this model uses plain ENet as the backbone; using ENet-SAD gives better performance.

### Testing
1. Run the test script
```Shell
cd $ENet_BDD100K_ROOT
sh ./experiments/test_ENet.sh
```
You should now be able to reproduce the reported result (Accuracy: 0.3656, mIoU: 16.02).

### Training
1. Train the ENet model
```Shell
cd $ENet_BDD100K_ROOT
sh ./experiments/train_ENet.sh
```
Training should start, and trained models will be saved to `experiments/models/ENet-new/` by default.
You can then test the trained model by following the Testing steps above. If you change the model location or name, remember to update the test configuration accordingly.
@@ -0,0 +1,71 @@
```lua
--
--  Copyright (c) 2016, Facebook, Inc.
--  All rights reserved.
--
--  This source code is licensed under the BSD-style license found in the
--  LICENSE file in the root directory of this source tree. An additional grant
--  of patent rights can be found in the PATENTS file in the same directory.
--
local checkpoint = {}

local function deepCopy(tbl)
   -- creates a copy of a network with new modules and the same tensors
   local copy = {}
   for k, v in pairs(tbl) do
      if type(v) == 'table' then
         copy[k] = deepCopy(v)
      else
         copy[k] = v
      end
   end
   if torch.typename(tbl) then
      torch.setmetatable(copy, torch.typename(tbl))
   end
   return copy
end

function checkpoint.latest(opt)
   if opt.resume == 'none' then
      return nil
   end

   local latestPath = paths.concat(opt.resume, 'latest.t7')
   if not paths.filep(latestPath) then
      return nil
   end

   print('=> Loading checkpoint ' .. latestPath)
   local latest = torch.load(latestPath)
   local optimState = torch.load(paths.concat(opt.resume, latest.optimFile))

   return latest, optimState
end

function checkpoint.save(epoch, model, optimState, isBestModel, opt, iter, bestLoss)
   -- don't save the DataParallelTable for easier loading on other machines
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end

   -- create a clean copy on the CPU without modifying the original network
   model = deepCopy(model):float():clearState()

   local modelFile = 'model_new.t7'
   local optimFile = 'optimState_new.t7'

   torch.save(paths.concat(opt.save, modelFile), model)
   torch.save(paths.concat(opt.save, optimFile), optimState)
   torch.save(paths.concat(opt.save, 'latest.t7'), {
      iter = iter,
      epoch = epoch,
      bestLoss = bestLoss,
      modelFile = modelFile,
      optimFile = optimFile,
   })

   if isBestModel then
      torch.save(paths.concat(opt.save, 'model_best.t7'), model)
   end
end

return checkpoint
```
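A short, hedged resume sketch follows, assuming this module lives at `checkpoints.lua` and that `opt.resume` points at the directory where `checkpoint.save` wrote its files; the directory value below is only an example taken from the README's default save path.

```lua
-- Hypothetical resume sketch; the module name 'checkpoints' and the paths are assumptions.
require 'torch'
require 'paths'
local checkpoints = require 'checkpoints'

local opt = {
   resume = './experiments/models/ENet-new',  -- 'none' would disable resuming
   save   = './experiments/models/ENet-new',
}

local latest, optimState = checkpoints.latest(opt)
if latest then
   -- modelFile, epoch and bestLoss are the fields written by checkpoint.save above
   local model = torch.load(paths.concat(opt.resume, latest.modelFile))
   print('resuming from epoch ' .. latest.epoch .. ', best loss ' .. tostring(latest.bestLoss))
else
   print('no checkpoint found, starting from scratch')
end
```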
@@ -0,0 +1,169 @@
```lua
--
--  Copyright (c) 2016, Facebook, Inc.
--  All rights reserved.
--
--  This source code is licensed under the BSD-style license found in the
--  LICENSE file in the root directory of this source tree. An additional grant
--  of patent rights can be found in the PATENTS file in the same directory.
--
--  Multi-threaded data loader
--

local datasets = require 'datasets/init'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local M = {}
local DataLoader = torch.class('resnet.DataLoader', M)

function DataLoader.create(opt)
   -- The train and val loader
   local loaders = {}
   local data
   if opt.dataset == 'lane' then
      data = {'train', 'val'}
   elseif opt.dataset == 'laneTest' then
      data = {'val'}
   else
      cmd:error('unknown dataset: ' .. opt.dataset)
   end
   for i, split in ipairs(data) do
      local dataset = datasets.create(opt, split)
      print("data created")
      loaders[i] = M.DataLoader(dataset, opt, split)
      print("data loaded")
   end

   return table.unpack(loaders)
end

function DataLoader:__init(dataset, opt, split)
   local manualSeed = opt.manualSeed
   local function init()
      require('datasets/' .. opt.dataset)
   end
   local function main(idx)
      if manualSeed ~= 0 then
         torch.manualSeed(manualSeed + idx)
      end
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess()
      _G.preprocess_aug = dataset:preprocess_aug()
      return dataset:size()
   end

   local threads, sizes = Threads(opt.nThreads, init, main)
   -- self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
   self.nCrops = 1
   self.threads = threads
   self.__size = sizes[1][1]
   self.batchSize = math.floor(opt.batchSize / self.nCrops)
   self.split = split
   self.dataset = opt.dataset
end

function DataLoader:size()
   return math.ceil(self.__size / self.batchSize)
end

function DataLoader:run()
   local threads = self.threads
   local size, batchSize = self.__size, self.batchSize
   local dataset = self.dataset
   --if self.split == 'val' then
   --   batchSize = torch.round(batchSize / 2)
   --end
   local perm
   if self.split == 'val' then
      perm = torch.Tensor(size)
      for i = 1, size do
         perm[i] = i
      end
   else
      perm = torch.randperm(size)
   end

   local idx, sample = 1, nil
   local function enqueue()
      while idx <= size and threads:acceptsjob() do
         local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
         threads:addjob(
            function(indices, nCrops)
               local sz = indices:size(1)
               local batch, segLabels, exists, imgpaths
               for i, idx in ipairs(indices:totable()) do
                  local sample = _G.dataset:get(idx)
                  local input, segLabel, exist
                  if dataset == 'laneTest' then
                     input = _G.preprocess(sample.input)
                  elseif dataset == 'lane' then
                     input, segLabel, exist = _G.preprocess_aug(sample.input, sample.segLabel, sample.exist)
                     --print(segLabel:size())
                     segLabel:resize(segLabel:size(2), segLabel:size(3))
                  else
                     cmd:error('unknown dataset: ' .. dataset)
                  end
                  if not batch then
                     local imageSize = input:size():totable()
                     local pathSize = sample.imgpath:size():totable()
                     batch = torch.FloatTensor(sz, table.unpack(imageSize))
                     imgpaths = torch.CharTensor(sz, table.unpack(pathSize))
                     if dataset == 'lane' then
                        local labelSize = segLabel:size():totable()
                        local existSize = exist:size():totable()
                        segLabels = torch.FloatTensor(sz, table.unpack(labelSize))
                        exists = torch.FloatTensor(sz, table.unpack(existSize))
                     end
                  end
                  batch[i]:copy(input)
                  imgpaths[i]:copy(sample.imgpath)
                  if dataset == 'lane' then
                     segLabels[i]:copy(segLabel)
                     exists[i]:copy(exist)
                  end
               end
               local targets
               if dataset == 'laneTest' then
                  targets = nil
               elseif dataset == 'lane' then
                  targets = {segLabels, exists}
               else
                  cmd:error('unknown dataset: ' .. dataset)
               end
               collectgarbage(); collectgarbage()

               return {
                  input = batch,
                  target = targets,
                  imgpath = imgpaths,  -- used in test
               }
            end,
            function(_sample_)
               sample = _sample_
            end,
            indices,
            self.nCrops
         )
         idx = idx + batchSize
      end
   end

   local n = 0
   local function loop()
      enqueue()
      if not threads:hasjob() then
         return nil
      end
      threads:dojob()
      if threads:haserror() then
         threads:synchronize()
      end
      enqueue()
      n = n + 1
      return n, sample
   end
   return loop
end

return M.DataLoader
```
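A minimal driver sketch, assuming this file is saved as `dataloader.lua`, that the `datasets/` scripts from this commit are on the Lua path, and that `opt` carries the fields read above (`dataset`, `nThreads`, `batchSize`, `manualSeed`) plus whatever fields the dataset scripts themselves expect; the concrete values below are illustrative only.

```lua
-- Hypothetical driver sketch; file name and opt values are assumptions.
require 'torch'
local DataLoader = require 'dataloader'

local opt = {
   dataset = 'lane',   -- 'lane' -> train + val loaders, 'laneTest' -> val loader only
   nThreads = 4,
   batchSize = 12,
   manualSeed = 0,
}

local trainLoader, valLoader = DataLoader.create(opt)

-- run() returns an iterator producing {input, target, imgpath} batches;
-- for dataset == 'lane', target is the {segLabels, exists} pair.
for n, sample in trainLoader:run() do
   print(n, sample.input:size(), sample.imgpath:size())
end
```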