Skip to content

Commit

Permalink
Network: optional rescale image input domain
Browse files Browse the repository at this point in the history
Past experiments have never shown this to improve results, but since
everyone else does it, optionally include it.

Do this in the network rather than the preprocessing for efficiency and
easy interoperability of trained networks.
  • Loading branch information
aschampion committed Oct 18, 2017
1 parent abb618e commit 595409e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
3 changes: 3 additions & 0 deletions diluvian/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class NetworkConfig(BaseConfig):
inputs and outputs. Data is assumed to be ZYX row-major, but old
versions of diluvian used XYZ, so this is necessary to load old
networks.
rescale_image : bool
If true, rescale the input image intensity from [0, 1) to [-1, 1).
num_modules : int
Number of convolution modules to use, each module consisting of a skip
link in parallel with ``num_layers_per_module`` convolution layers.
Expand Down Expand Up @@ -176,6 +178,7 @@ class NetworkConfig(BaseConfig):
def __init__(self, settings):
self.factory = str(settings.get('factory'))
self.transpose = bool(settings.get('transpose', False))
self.rescale_image = bool(settings.get('rescale_image', False))
self.num_modules = int(settings.get('num_modules', 8))
self.num_layers_per_module = int(settings.get('num_layers_per_module', 2))
self.convolution_dim = np.array(settings.get('convolution_dim', [3, 3, 3]))
Expand Down
13 changes: 11 additions & 2 deletions diluvian/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Cropping3D,
Dropout,
Input,
Lambda,
Permute,
)
from keras.layers.merge import (
Expand All @@ -34,8 +35,12 @@ def make_flood_fill_network(input_fov_shape, output_fov_shape, network_config):
raise ValueError('ResNet implementation only supports same padding.')

image_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='image_input')
if network_config.rescale_image:
ffn = Lambda(lambda x: (x - 0.5) * 2.0)(image_input)
else:
ffn = image_input
mask_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='mask_input')
ffn = concatenate([image_input, mask_input])
ffn = concatenate([ffn, mask_input])

# Convolve and activate before beginning the skip connection modules,
# as discussed in the Appendix of He et al 2016.
Expand Down Expand Up @@ -109,8 +114,12 @@ def make_flood_fill_unet(input_fov_shape, output_fov_shape, network_config):
"""Construct a U-net flood filling network.
"""
image_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='image_input')
if network_config.rescale_image:
ffn = Lambda(lambda x: (x - 0.5) * 2.0)(image_input)
else:
ffn = image_input
mask_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='mask_input')
ffn = concatenate([image_input, mask_input])
ffn = concatenate([ffn, mask_input])

# Note that since the Keras 2 upgrade strangely models with depth > 3 are
# rejected by TF.
Expand Down

0 comments on commit 595409e

Please sign in to comment.