Skip to content

Commit

Permalink
Add multi-GPU saved model kludge to fill
Browse files Browse the repository at this point in the history
  • Loading branch information
aschampion committed Feb 7, 2017
1 parent 5e0ab2d commit 6e13585
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
15 changes: 12 additions & 3 deletions diluvian.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def on_batch_end(self, batch, logs={}):
self.kludge['outputs'] = self.model.predict(self.kludge['inputs'])


def fill_region_from_model(model_file, volumes=None, bias=True, move_batch_size=1, max_moves=None):
def fill_region_from_model(model_file, volumes=None, bias=True, move_batch_size=1,
max_moves=None, multi_gpu_model_kludge=None):
if volumes is None:
raise ValueError('Volumes must be provided.')

Expand All @@ -108,7 +109,11 @@ def fill_region_from_model(model_file, volumes=None, bias=True, move_batch_size=

for region in regions:
region.bias_against_merge = bias
region.fill(model, verbose=True, move_batch_size=move_batch_size, max_moves=max_moves)
region.fill(model,
verbose=True,
move_batch_size=move_batch_size,
max_moves=max_moves,
multi_gpu_pad_kludge=multi_gpu_model_kludge)
viewer = region.get_viewer()
print viewer
s = raw_input("Press Enter to continue, a to export animation, q to quit...")
Expand Down Expand Up @@ -258,6 +263,9 @@ def cli():
help='Maximum number of fill moves to process in each prediction batch.')
fill_parser.add_argument('--max-moves', dest='max_moves', default=None, type=int,
help='Cancel filling after this many moves.')
fill_parser.add_argument('--multi-gpu-model-kludge', dest='multi_gpu_model_kludge', default=None, type=int,
help='Fix using a multi-GPU trained model that was not saved properly by '
'setting this to the number of training GPUs.')

args = parser.parse_args()

Expand All @@ -283,7 +291,8 @@ def cli():
volumes=volumes,
bias=args.bias,
move_batch_size=args.move_batch_size,
max_moves=args.max_moves)
max_moves=args.max_moves,
multi_gpu_model_kludge=args.multi_gpu_model_kludge)


if __name__ == "__main__":
Expand Down
13 changes: 12 additions & 1 deletion regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_next_block(self):
'target': target_block,
'position': next_pos}

def fill(self, model, verbose=False, move_batch_size=1, max_moves=None):
def fill(self, model, verbose=False, move_batch_size=1, max_moves=None, multi_gpu_pad_kludge=None):
moves = 0
if verbose:
pbar = tqdm(desc='Move queue')
Expand All @@ -133,6 +133,17 @@ def fill(self, model, verbose=False, move_batch_size=1, max_moves=None):
image_input = np.concatenate([pad_dims(b['image']) for b in batch_block_data])
mask_input = np.concatenate([pad_dims(b['mask']) for b in batch_block_data])

# For models generated with make_parallel that saved the parallel
# model, not the original model, some kludge is necessary so that
# the batch size is large enough to give each GPU in the parallel
# model an equal number of samples.
if multi_gpu_pad_kludge is not None and image_input.shape[0] % multi_gpu_pad_kludge != 0:
missing_samples = multi_gpu_pad_kludge - (image_input.shape[0] % multi_gpu_pad_kludge)
fill_dim = list(image_input.shape)
fill_dim[0] = missing_samples
image_input = np.concatenate((image_input, np.zeros(fill_dim, dtype=image_input.dtype)))
mask_input = np.concatenate((mask_input, np.zeros(fill_dim, dtype=mask_input.dtype)))

output = model.predict({'image_input': image_input,
'mask_input': mask_input})

Expand Down

0 comments on commit 6e13585

Please sign in to comment.