Skip to content

Commit

Permalink
Fix restricted sample indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
tskisner committed Nov 15, 2024
1 parent a4249c1 commit 147694b
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions sotodlib/toast/ops/load_context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,13 @@ def parse_metadata(axman, obs_meta, fp_cols, path_sep, det_axis, obs_base, fp_ba


def _local_process_dets(obs, rank):
"""Get the list of local detectors for a specific rank.
"""
"""Get the list of local detectors for a specific rank."""
# Full detector list
det_names = obs.all_detectors
# The range of full detector indices on this process.
full_indices = obs.dist.det_indices[rank]
# The same info, as a slice.
full_slc = slice(
full_indices.offset, full_indices.offset + full_indices.n_elem, 1
)
full_slc = slice(full_indices.offset, full_indices.offset + full_indices.n_elem, 1)
# The names of the detectors on this process.
return det_names[full_slc]

Expand Down Expand Up @@ -415,12 +412,13 @@ def _process_flag(mask, inbuf, outbuf, invert):
# The receiving detector names and their relative indices
recv_dets = proc_wafer_dets[receiver][wafer]
n_recv_det = recv_dets[1] - recv_dets[0]
recv_det_names = local_det_names[recv_dets[0]:recv_dets[1]]
recv_det_names = local_det_names[recv_dets[0] : recv_dets[1]]
recv_det_indices = {y: x for x, y in enumerate(recv_det_names)}

# Build the mapping of restricted indices to send buffer indices
restrict_to_send = {
restricted_indices[x]: y for x, y in recv_det_indices.items()
restricted_indices[x]: y
for x, y in recv_det_indices.items()
if x in restricted_indices
}

Expand All @@ -447,38 +445,50 @@ def _process_flag(mask, inbuf, outbuf, invert):
# If so, we construct a temporary buffer and build sample flags
# from the ranges.
if isinstance(axwafers[wafer][axfield], so3g.proj.RangesMatrix):
# Yes, flagged ranges
sdata = np.empty(
# Yes, flagged ranges. We may have restricted sample ranges
# for our flags, and so we initialize the full buffer to
# the invalid mask.
sdata = defaults.det_mask_invalid * np.ones(
flat_size,
dtype=np.uint8,
)
if flag_invert:
sdata[:] = flag_mask
for idet_ax, idet_send in restrict_to_send.items():
# Set this detector's values to the mask, and then
# we will "unflag" the specified ranges.
off = idet_send * obs.n_local_samples + ax_shift
sdata[off : off + restricted_samps] = flag_mask
for rg in axwafers[wafer][axfield][idet_ax].ranges():
sdata[off + rg[0] : off + rg[1]] = 0
else:
sdata[:] = 0
for idet_ax, idet_send in restrict_to_send.items():
# Set this detector's values to good, and then
# we will flag the specified ranges.
off = idet_send * obs.n_local_samples + ax_shift
sdata[off : off + restricted_samps] = 0
for rg in axwafers[wafer][axfield][idet_ax].ranges():
sdata[off + rg[0] : off + rg[1]] = flag_mask
else:
# Either normal sample flags or signal data
if is_flag:
sdata = flag_mask * np.ones(
# We may have restricted sample ranges for our flags, and so
# we initialize the full buffer to the invalid mask.
sdata = defaults.det_mask_invalid * np.ones(
flat_size,
dtype=axdtype,
)
else:
# Signal data is initialized to zero. Restricted / missing
# samples will be indicated by flags.
sdata = np.zeros(
flat_size,
dtype=axdtype,
)
for idet_ax, idet_send in restrict_to_send.items():
off = idet_send * obs.n_local_samples + ax_shift
sdata[off : off + obs.n_local_samples] = axwafers[wafer][axfield][idet_ax, :]
sdata[off : off + restricted_samps] = axwafers[wafer][axfield][
idet_ax, :
]

if receiver == rank:
# We just need to process the data locally
Expand All @@ -496,7 +506,8 @@ def _process_flag(mask, inbuf, outbuf, invert):
# Update per-detector flags
dflags = {
obs.local_detectors[recv_dets[0] + x]: defaults.det_mask_invalid
for x in range(n_recv_det) if det_flags[x] != 0
for x in range(n_recv_det)
if det_flags[x] != 0
}
obs.update_local_detector_flags(dflags)
else:
Expand Down Expand Up @@ -534,13 +545,12 @@ def _process_flag(mask, inbuf, outbuf, invert):
)
else:
# Just assign
obs.detdata[field][det_slc, :] = recv_data.reshape(
(n_recv_det, -1)
)
obs.detdata[field][det_slc, :] = recv_data.reshape((n_recv_det, -1))
# Update per-detector flags
dflags = {
obs.local_detectors[recv_dets[0] + x]: defaults.det_mask_invalid
for x in range(n_recv_det) if det_flags[x] != 0
for x in range(n_recv_det)
if det_flags[x] != 0
}
obs.update_local_detector_flags(dflags)
del recv_data
Expand Down

0 comments on commit 147694b

Please sign in to comment.