Skip to content

Commit

Permalink
fix: fixes reduction chunk incrementation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
dodamih authored and supersergiy committed Jan 15, 2025
1 parent c1cc05f commit 13fa91a
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def make_tasks_with_intermediaries( # pylint: disable=too-many-locals
tasks = self.make_tasks_without_checkerboarding(idx_chunks, dst_temp, op_kwargs)
return tasks, dst_temp

def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-many-branches
def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-many-branches, too-many-statements
self,
idx: VolumetricIndex,
red_chunks: List[VolumetricIndex],
Expand Down Expand Up @@ -582,7 +582,7 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man
next_chunk_id = task_idxs[-1].chunk_id + self.l0_chunks_per_task

# get offsets in terms of index inds
red_chunk_offsets = list(itertools.product([0, 1, 2], [0, 1, 2], [0, 1, 2]))
red_chunk_offsets = tuple(itertools.product((0, 1), (0, 1), (0, 1)))
# `set` to handle cases where the shape is 1 in some dimensions
red_ind_offsets = set(
offset[2] * red_shape[0] * red_shape[1] + offset[1] * red_shape[0] + offset[0]
Expand Down Expand Up @@ -612,28 +612,28 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man
tasks += tasks_split

red_ind = 0

last_y_rollover = 0
last_xy_rollover = 0
for task_idx in task_idxs:
if not task_idx.intersects(red_chunks[red_ind]):
# roll over in Z
if red_ind + 1 - red_shape[0] * red_shape[1] >= 0 and task_idx.intersects(
red_chunks[red_ind + 1 - red_shape[0] * red_shape[1]]
):
red_ind = red_ind + 1 - red_shape[0] * red_shape[1]
# roll over in Y
elif red_ind + 1 - red_shape[0] >= 0 and task_idx.intersects(
red_chunks[red_ind + 1 - red_shape[0]]
):
red_ind = red_ind + 1 - red_shape[0]
# test rollover in Z and Y
if task_idx.intersects(red_chunks[last_xy_rollover]):
red_ind = last_xy_rollover
elif task_idx.intersects(red_chunks[last_y_rollover]):
red_ind = last_y_rollover
# all other cases
else:
try:
while not task_idx.intersects(red_chunks[red_ind]):
red_ind += 1
if red_ind % (red_shape[0] * red_shape[1]) == 0:
last_xy_rollover = red_ind
if red_ind % red_shape[0] == 0:
last_y_rollover = red_ind
# This case catches the case where the chunk is entirely outside
# any # reduction chunk; this can happen if, for instance,
# any reduction chunk; this can happen if, for instance,
# roi_crop_pad is set to [0, 0, 1] and the processing_chunk_size
# is [X, X, 1]
# is [X, X, 1].
except IndexError as e:
raise ValueError(
f"The processing chunk `{task_idx.pformat()}` does not "
Expand All @@ -644,7 +644,7 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man
red_chunks_task_idxs[red_ind].append(task_idx)
red_chunks_temps[red_ind].append(dst_temp)
else:
# check for up to 3 chunks in each dimension
# check for 2 chunks in each dimension
inds = [red_ind + red_ind_offset for red_ind_offset in red_ind_offsets]
for i in inds:
if i < len(red_chunks) and task_idx.intersects(red_chunks[i]):
Expand Down

0 comments on commit 13fa91a

Please sign in to comment.