Skip to content

Commit

Permalink
fix: revert to 3 chunks, and check for rollover every time
Browse files Browse the repository at this point in the history
  • Loading branch information
dodamih committed Jan 25, 2025
1 parent 68e5fd1 commit 4cac4a8
Showing 1 changed file with 31 additions and 29 deletions.
60 changes: 31 additions & 29 deletions zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man
next_chunk_id = task_idxs[-1].chunk_id + self.l0_chunks_per_task

# get offsets in terms of index inds
red_chunk_offsets = tuple(itertools.product((0, 1), (0, 1), (0, 1)))
red_chunk_offsets = tuple(itertools.product((0, 1, 2), (0, 1, 2), (0, 1, 2)))
# `set` to handle cases where the shape is 1 in some dimensions
red_ind_offsets = set(
offset[2] * red_shape[0] * red_shape[1] + offset[1] * red_shape[0] + offset[0]
Expand Down Expand Up @@ -612,39 +612,41 @@ def make_tasks_with_checkerboarding( # pylint: disable=too-many-locals, too-man
tasks += tasks_split

red_ind = 0
last_y_rollover = 0
last_x_rollover = 0
last_xy_rollover = 0
for task_idx in task_idxs:
if not task_idx.intersects(red_chunks[red_ind]):
# test rollover in Z and Y
if task_idx.intersects(red_chunks[last_xy_rollover]):
red_ind = last_xy_rollover
elif task_idx.intersects(red_chunks[last_y_rollover]):
red_ind = last_y_rollover
# all other cases
else:
try:
while not task_idx.intersects(red_chunks[red_ind]):
red_ind += 1
if red_ind % (red_shape[0] * red_shape[1]) == 0:
last_xy_rollover = red_ind
if red_ind % red_shape[0] == 0:
last_y_rollover = red_ind
# This case catches the case where the chunk is entirely outside
# any reduction chunk; this can happen if, for instance,
# roi_crop_pad is set to [0, 0, 1] and the processing_chunk_size
# is [X, X, 1].
except IndexError as e:
raise ValueError(
f"The processing chunk `{task_idx.pformat()}` does not "
" correspond to any reduction chunk; please check the "
"`roi_crop_pad` and the `processing_chunk_size`."
) from e
for i, task_idx in enumerate(task_idxs):
# test rollover in X, and then XY - note that this
# also catches situations where the task_idx intersects
# the current reduction chunk but is hanging off the
# `left` edge
print(red_chunks[red_ind].shape, task_idx.shape)
if task_idx.intersects(red_chunks[last_xy_rollover]):
red_ind = last_xy_rollover
elif task_idx.intersects(red_chunks[last_x_rollover]):
red_ind = last_x_rollover
# all other cases
else:
try:
while not task_idx.intersects(red_chunks[red_ind]):
red_ind += 1
if red_ind % (red_shape[0] * red_shape[1]) == 0:
last_xy_rollover = red_ind
if red_ind % red_shape[0] == 0:
last_x_rollover = red_ind
# This case catches the case where the chunk is entirely outside
# any reduction chunk; this can happen if, for instance,
# roi_crop_pad is set to [0, 0, 1] and the processing_chunk_size
# is [X, X, 1].
except IndexError as e:
raise ValueError(
f"The processing chunk `{task_idx.pformat()}` does not "
" correspond to any reduction chunk; please check the "
"`roi_crop_pad` and the `processing_chunk_size`."
) from e
if task_idx.contained_in(red_chunks[red_ind]):
red_chunks_task_idxs[red_ind].append(task_idx)
red_chunks_temps[red_ind].append(dst_temp)
else:
# check for 2 chunks in each dimension
inds = [red_ind + red_ind_offset for red_ind_offset in red_ind_offsets]
for i in inds:
if i < len(red_chunks) and task_idx.intersects(red_chunks[i]):
Expand Down

0 comments on commit 4cac4a8

Please sign in to comment.