Add multiple GPUs support
yujincheng08 authored and yujincheng committed Dec 19, 2018
1 parent b14b219 commit 9c63a75
Showing 6 changed files with 376 additions and 438 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
+venv
+.idea
17 changes: 10 additions & 7 deletions SCNN-Tensorflow/lane-detection-model/config/global_config.py
@@ -19,34 +19,37 @@
 __C.TRAIN = edict()
 
 # Set the shadownet training epochs
-__C.TRAIN.EPOCHS = 90100 # 200010
+__C.TRAIN.EPOCHS = 90100 # 200010
 # Set the display step
 __C.TRAIN.DISPLAY_STEP = 1
 # Set the test display step during training process
 __C.TRAIN.TEST_DISPLAY_STEP = 1000
 # Set the momentum parameter of the optimizer
 __C.TRAIN.MOMENTUM = 0.9
 # Set the initial learning rate
-__C.TRAIN.LEARNING_RATE = 0.01 # 0.0005
+__C.TRAIN.LEARNING_RATE = 0.01 # 0.0005
 # Set the GPU resource used during training process
 __C.TRAIN.GPU_MEMORY_FRACTION = 0.85
 # Set the GPU allow growth parameter during tensorflow training process
 __C.TRAIN.TF_ALLOW_GROWTH = True
 # Set the shadownet training batch size
-__C.TRAIN.BATCH_SIZE = 8 # 4
-
+__C.TRAIN.BATCH_SIZE = 4 # 4
 # Set the shadownet validation batch size
-__C.TRAIN.VAL_BATCH_SIZE = 8 # 4
+__C.TRAIN.VAL_BATCH_SIZE = 4 # 4
 # Set the learning rate decay steps
 __C.TRAIN.LR_DECAY_STEPS = 210000
 # Set the learning rate decay rate
 __C.TRAIN.LR_DECAY_RATE = 0.1
 # Set the class numbers
 __C.TRAIN.CLASSES_NUMS = 2
 # Set the image height
-__C.TRAIN.IMG_HEIGHT = 384 # 256
+__C.TRAIN.IMG_HEIGHT = 288 # 256
 # Set the image width
-__C.TRAIN.IMG_WIDTH = 608 # 512
+__C.TRAIN.IMG_WIDTH = 800 # 512
+# Set GPU number
+__C.TRAIN.GPU_NUM = 1 # 8
+# Set CPU thread number
+__C.TRAIN.CPU_NUM = 1 #
 
 # Test options
 __C.TEST = edict()
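The new GPU_NUM and CPU_NUM options are presumably consumed by the training script, which is among the changed files not shown in this excerpt. The sketch below is only an illustration of the standard TensorFlow 1.x multi-GPU tower pattern that such an option usually drives; build_loss is a hypothetical stand-in for the real model and loss construction, not code from this commit.

import tensorflow as tf

from config import global_config

CFG = global_config.cfg


def train_op_over_towers(batch, optimizer, build_loss):
    """Sketch: split one batch across CFG.TRAIN.GPU_NUM GPUs and average the gradients.

    build_loss is a hypothetical callable returning a scalar loss; the real model
    code is not part of this excerpt.
    """
    img, label_instance, label_existence = batch
    # One shard per GPU along the batch dimension.
    img_shards = tf.split(img, CFG.TRAIN.GPU_NUM, axis=0)
    instance_shards = tf.split(label_instance, CFG.TRAIN.GPU_NUM, axis=0)
    existence_shards = tf.split(label_existence, CFG.TRAIN.GPU_NUM, axis=0)

    tower_grads = []
    for i in range(CFG.TRAIN.GPU_NUM):
        with tf.device('/gpu:%d' % i):
            # Reuse the same variables for every tower after the first one.
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                loss = build_loss(img_shards[i], instance_shards[i], existence_shards[i])
                tower_grads.append(optimizer.compute_gradients(loss))

    # Average the per-tower gradients variable by variable
    # (assumes every variable receives a gradient on every tower).
    averaged = []
    for grads_and_vars in zip(*tower_grads):
        grads = tf.stack([g for g, _ in grads_and_vars], axis=0)
        averaged.append((tf.reduce_mean(grads, axis=0), grads_and_vars[0][1]))

    return optimizer.apply_gradients(averaged)

With the default GPU_NUM = 1 left in the config, this degenerates to ordinary single-GPU training.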
@@ -8,15 +8,12 @@
 """
 Implements the data parsing class for LaneNet
 """
-import os.path as ops
+import tensorflow as tf
 
-import cv2
-import numpy as np
+from config import global_config
 
-try:
-    from cv2 import cv2
-except ImportError:
-    pass
+CFG = global_config.cfg
+VGG_MEAN = [103.939, 116.779, 123.68]
 
 
 class DataSet(object):
@@ -26,91 +23,73 @@ class DataSet(object):
 
     def __init__(self, dataset_info_file):
         """
         :param dataset_info_file:
         """
-        self._gt_img_list, \
-            self._gt_label_instance_list, self._gt_label_existence_list = self._init_dataset(dataset_info_file)
-        self._random_dataset()
-        self._next_batch_loop_count = 0
-
-    def _init_dataset(self, dataset_info_file):
+        self._len = 0
+        self.dataset_info_file = dataset_info_file
+        self._img, self._label_instance, self._label_existence = self._init_dataset()
+
+    def __len__(self):
+        return self._len
+
+    def distorted_inputs(self):
+        pass
+
+    @staticmethod
+    def process_img(img_queue):
+        img_raw = tf.read_file(img_queue)
+        img_decoded = tf.image.decode_jpeg(img_raw, channels=3)
+        img_resized = tf.image.resize_images(img_decoded, [CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH],
+                                             method=tf.image.ResizeMethod.BICUBIC)
+        img_casted = tf.cast(img_resized, tf.float32)
+        return tf.subtract(img_casted, VGG_MEAN)
+
+    @staticmethod
+    def process_label_instance(label_instance_queue):
+        label_instance_raw = tf.read_file(label_instance_queue)
+        label_instance_decoded = tf.image.decode_png(label_instance_raw, channels=1)
+        label_instance_resized = tf.image.resize_images(label_instance_decoded,
+                                                        [CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH],
+                                                        method=tf.image.ResizeMethod.BICUBIC)
+        return tf.cast(label_instance_resized, tf.int32)
+
+    @staticmethod
+    def process_label_existence(label_existence_queue):
+        return tf.cast(label_existence_queue, tf.float32)
+
+    def _init_dataset(self):
         """
-        :param dataset_info_file:
         :return:
         """
-        gt_img_list = []
-        gt_label_instance_list = []
-        gt_label_existence_list = []
+        if not tf.gfile.Exists(self.dataset_info_file):
+            raise ValueError('Failed to find file: ' + self.dataset_info_file)
 
-        assert ops.exists(dataset_info_file), '{:s} does not exist'.format(dataset_info_file)
+        img_list = []
+        label_instance_list = []
+        label_existence_list = []
 
-        with open(dataset_info_file, 'r') as file:
+        with open(self.dataset_info_file, 'r') as file:
             for _info in file:
                 info_tmp = _info.strip(' ').split()
 
-                gt_img_list.append(info_tmp[0][1:])
-                gt_label_instance_list.append(info_tmp[1][1:])
-                gt_label_existence_list.append([int(info_tmp[2]), int(info_tmp[3]), int(info_tmp[4]), int(info_tmp[5])])
-
-        return gt_img_list, gt_label_instance_list, gt_label_existence_list
-
-    def _random_dataset(self):
-        """
-        :return:
-        """
-        assert len(self._gt_img_list) == len(self._gt_label_instance_list) == len(self._gt_label_existence_list)
-
-        random_idx = np.random.permutation(len(self._gt_img_list))
-        new_gt_img_list = []
-        new_gt_label_instance_list = []
-        new_gt_label_existence_list = []
+                img_list.append(info_tmp[0][1:])
+                label_instance_list.append(info_tmp[1][1:])
+                label_existence_list.append([int(info_tmp[2]), int(info_tmp[3]), int(info_tmp[4]), int(info_tmp[5])])
 
-        for index in random_idx:
-            new_gt_img_list.append(self._gt_img_list[index])
-            new_gt_label_instance_list.append(self._gt_label_instance_list[index])
-            new_gt_label_existence_list.append(self._gt_label_existence_list[index])
+        self._len = len(img_list)
+        # img_queue = tf.train.string_input_producer(img_list)
+        # label_instance_queue = tf.train.string_input_producer(label_instance_list)
+        with tf.name_scope('data_augmentation'):
+            image_tensor = tf.convert_to_tensor(img_list)
+            label_instance_tensor = tf.convert_to_tensor(label_instance_list)
+            label_existence_tensor = tf.convert_to_tensor(label_existence_list)
+            input_queue = tf.train.slice_input_producer([image_tensor, label_instance_tensor, label_existence_tensor])
+            img = self.process_img(input_queue[0])
+            label_instance = self.process_label_instance(input_queue[1])
+            label_existence = self.process_label_existence(input_queue[2])
 
-        self._gt_img_list = new_gt_img_list
-        self._gt_label_instance_list = new_gt_label_instance_list
-        self._gt_label_existence_list = new_gt_label_existence_list
+        return img, label_instance, label_existence
 
     def next_batch(self, batch_size):
         """
         :param batch_size:
         :return:
         """
-        assert len(self._gt_label_instance_list) == len(self._gt_label_existence_list) \
-            == len(self._gt_img_list)
-
-        idx_start = batch_size * self._next_batch_loop_count
-        idx_end = batch_size * self._next_batch_loop_count + batch_size
-
-        if idx_end > len(self._gt_label_instance_list):
-            self._random_dataset()
-            self._next_batch_loop_count = 0
-            return self.next_batch(batch_size)
-        else:
-            gt_img_list = self._gt_img_list[idx_start:idx_end]
-            gt_label_instance_list = self._gt_label_instance_list[idx_start:idx_end]
-            gt_label_existence_list = self._gt_label_existence_list[idx_start:idx_end]
-
-            gt_imgs = []
-            gt_labels_instance = []
-            gt_labels_existence = []
-
-            for gt_img_path in gt_img_list:
-                gt_imgs.append(cv2.imread(gt_img_path, cv2.IMREAD_COLOR))
-
-            for gt_label_path in gt_label_instance_list:
-                label_img = cv2.imread(gt_label_path, cv2.IMREAD_UNCHANGED)
-                gt_labels_instance.append(label_img)
-
-            gt_labels_existence = gt_label_existence_list
-
-            self._next_batch_loop_count += 1
-            return gt_imgs, gt_labels_instance, gt_labels_existence
-
+        return tf.train.batch([self._img, self._label_instance, self._label_existence], batch_size=batch_size,
+                              num_threads=CFG.TRAIN.CPU_NUM)
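After this rewrite, DataSet builds a queue-backed input pipeline, so next_batch returns symbolic tensors rather than NumPy arrays and a session must run the queue runners behind slice_input_producer and tf.train.batch. A minimal usage sketch under that assumption follows; the module path and the data/train_gt.txt index file are illustrative assumptions, not taken from this commit.

import tensorflow as tf

from config import global_config
from data_provider import lanenet_data_processor  # module path assumed from the repo layout

CFG = global_config.cfg

# next_batch now returns symbolic tensors fed by slice_input_producer / tf.train.batch.
dataset = lanenet_data_processor.DataSet('data/train_gt.txt')  # hypothetical index file
img, label_instance, label_existence = dataset.next_batch(CFG.TRAIN.BATCH_SIZE)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # Queue runners must be started, otherwise sess.run() blocks forever.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    imgs, instances, existences = sess.run([img, label_instance, label_existence])
    print(imgs.shape, instances.shape, existences.shape)
    coord.request_stop()
    coord.join(threads)

If the batch is later split across GPU towers, the batch size passed in would typically be CFG.TRAIN.BATCH_SIZE * CFG.TRAIN.GPU_NUM, but that wiring lives in the training script, not in this file.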