From 13fa91af6509c0db6386c8ca8006abcf1f650b45 Mon Sep 17 00:00:00 2001 From: Dodam Ih Date: Sun, 5 Jan 2025 09:54:42 -0800 Subject: [PATCH] fix: fixes reduction chunk incrementation logic --- .../common/volumetric_apply_flow.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index fb88ba876..ba30b9a3e 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -515,7 +515,7 @@ def make_tasks_with_intermediaries( # pylint: disable=too-many-locals tasks = self.make_tasks_without_checkerboarding(idx_chunks, dst_temp, op_kwargs) return tasks, dst_temp - def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-many-branches + def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-many-branches, too-many-statements self, idx: VolumetricIndex, red_chunks: List[VolumetricIndex], @@ -582,7 +582,7 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man next_chunk_id = task_idxs[-1].chunk_id + self.l0_chunks_per_task # get offsets in terms of index inds - red_chunk_offsets = list(itertools.product([0, 1, 2], [0, 1, 2], [0, 1, 2])) + red_chunk_offsets = tuple(itertools.product((0, 1), (0, 1), (0, 1))) # `set` to handle cases where the shape is 1 in some dimensions red_ind_offsets = set( offset[2] * red_shape[0] * red_shape[1] + offset[1] * red_shape[0] + offset[0] @@ -612,28 +612,28 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man tasks += tasks_split red_ind = 0 - + last_y_rollover = 0 + last_xy_rollover = 0 for task_idx in task_idxs: if not task_idx.intersects(red_chunks[red_ind]): - # roll over in Z - if red_ind + 1 - red_shape[0] * red_shape[1] >= 0 and task_idx.intersects( - red_chunks[red_ind + 1 - red_shape[0] * red_shape[1]] - ): - red_ind = red_ind + 1 - red_shape[0] * red_shape[1] - # roll over in Y - elif red_ind + 1 - red_shape[0] >= 0 and task_idx.intersects( - red_chunks[red_ind + 1 - red_shape[0]] - ): - red_ind = red_ind + 1 - red_shape[0] + # test rollover in Z and Y + if task_idx.intersects(red_chunks[last_xy_rollover]): + red_ind = last_xy_rollover + elif task_idx.intersects(red_chunks[last_y_rollover]): + red_ind = last_y_rollover # all other cases else: try: while not task_idx.intersects(red_chunks[red_ind]): red_ind += 1 + if red_ind % (red_shape[0] * red_shape[1]) == 0: + last_xy_rollover = red_ind + if red_ind % red_shape[0] == 0: + last_y_rollover = red_ind # This case catches the case where the chunk is entirely outside - # any # reduction chunk; this can happen if, for instance, + # any reduction chunk; this can happen if, for instance, # roi_crop_pad is set to [0, 0, 1] and the processing_chunk_size - # is [X, X, 1] + # is [X, X, 1]. except IndexError as e: raise ValueError( f"The processing chunk `{task_idx.pformat()}` does not " @@ -644,7 +644,7 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man red_chunks_task_idxs[red_ind].append(task_idx) red_chunks_temps[red_ind].append(dst_temp) else: - # check for up to 3 chunks in each dimension + # check for 2 chunks in each dimension inds = [red_ind + red_ind_offset for red_ind_offset in red_ind_offsets] for i in inds: if i < len(red_chunks) and task_idx.intersects(red_chunks[i]):