diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index ba30b9a3e..3ee0bb18d 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -582,7 +582,7 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man next_chunk_id = task_idxs[-1].chunk_id + self.l0_chunks_per_task # get offsets in terms of index inds - red_chunk_offsets = tuple(itertools.product((0, 1), (0, 1), (0, 1))) + red_chunk_offsets = tuple(itertools.product((0, 1, 2), (0, 1, 2), (0, 1, 2))) # `set` to handle cases where the shape is 1 in some dimensions red_ind_offsets = set( offset[2] * red_shape[0] * red_shape[1] + offset[1] * red_shape[0] + offset[0] @@ -612,39 +612,41 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man tasks += tasks_split red_ind = 0 - last_y_rollover = 0 + last_x_rollover = 0 last_xy_rollover = 0 - for task_idx in task_idxs: - if not task_idx.intersects(red_chunks[red_ind]): - # test rollover in Z and Y - if task_idx.intersects(red_chunks[last_xy_rollover]): - red_ind = last_xy_rollover - elif task_idx.intersects(red_chunks[last_y_rollover]): - red_ind = last_y_rollover - # all other cases - else: - try: - while not task_idx.intersects(red_chunks[red_ind]): - red_ind += 1 - if red_ind % (red_shape[0] * red_shape[1]) == 0: - last_xy_rollover = red_ind - if red_ind % red_shape[0] == 0: - last_y_rollover = red_ind - # This case catches the case where the chunk is entirely outside - # any reduction chunk; this can happen if, for instance, - # roi_crop_pad is set to [0, 0, 1] and the processing_chunk_size - # is [X, X, 1]. - except IndexError as e: - raise ValueError( - f"The processing chunk `{task_idx.pformat()}` does not " - " correspond to any reduction chunk; please check the " - "`roi_crop_pad` and the `processing_chunk_size`." - ) from e + for i, task_idx in enumerate(task_idxs): + # test rollover in X, and then XY - note that this + # also catches situations where the task_idx intersects + # the current reduction chunk but is hanging off the + # `left` edge + print(red_chunks[red_ind].shape, task_idx.shape) + if task_idx.intersects(red_chunks[last_xy_rollover]): + red_ind = last_xy_rollover + elif task_idx.intersects(red_chunks[last_x_rollover]): + red_ind = last_x_rollover + # all other cases + else: + try: + while not task_idx.intersects(red_chunks[red_ind]): + red_ind += 1 + if red_ind % (red_shape[0] * red_shape[1]) == 0: + last_xy_rollover = red_ind + if red_ind % red_shape[0] == 0: + last_x_rollover = red_ind + # This case catches the case where the chunk is entirely outside + # any reduction chunk; this can happen if, for instance, + # roi_crop_pad is set to [0, 0, 1] and the processing_chunk_size + # is [X, X, 1]. + except IndexError as e: + raise ValueError( + f"The processing chunk `{task_idx.pformat()}` does not " + " correspond to any reduction chunk; please check the " + "`roi_crop_pad` and the `processing_chunk_size`." + ) from e if task_idx.contained_in(red_chunks[red_ind]): red_chunks_task_idxs[red_ind].append(task_idx) red_chunks_temps[red_ind].append(dst_temp) else: - # check for 2 chunks in each dimension inds = [red_ind + red_ind_offset for red_ind_offset in red_ind_offsets] for i in inds: if i < len(red_chunks) and task_idx.intersects(red_chunks[i]):