diff --git a/pytorch3dunet/datasets/hdf5.py b/pytorch3dunet/datasets/hdf5.py index 205c164d..040adb85 100644 --- a/pytorch3dunet/datasets/hdf5.py +++ b/pytorch3dunet/datasets/hdf5.py @@ -87,11 +87,10 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r # compare patch and stride configuration patch_shape = slice_builder_config.get('patch_shape') stride_shape = slice_builder_config.get('stride_shape') - if patch_shape != stride_shape: - logger.warning(f'Patch shape and stride shape should be equal for optimal prediction performance,' - f'but found patch_shape: {patch_shape} and stride_shape: {stride_shape} in the config!' - f'Overriding stride_shape to match patch_shape!') - slice_builder_config['stride_shape'] = patch_shape + if sum(self.halo_shape) != 0 and patch_shape != stride_shape: + logger.warning(f'Found non-zero halo shape {self.halo_shape}. ' + f'In this case: patch shape and stride shape should be equal for optimal prediction ' + f'performance, but found patch_shape: {patch_shape} and stride_shape: {stride_shape}!') with h5py.File(file_path, 'r') as f: raw = f[raw_internal_path]