Skip to content

Commit

Permalink
add LM MLEM recon
Browse files Browse the repository at this point in the history
  • Loading branch information
Georg Schramm authored and Georg Schramm committed Nov 8, 2024
1 parent 350ba5a commit c2f703a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 41 deletions.
93 changes: 55 additions & 38 deletions python/recon_block_scanner_listmode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def transform_BoxShape(
assert num_tofbins % 2 == 1, "Number of TOF bins must be odd"
# %%
# calculate the sensitivity image

print("Calculating sensitivity image")

# we loop through the symmetric group ID look up table to see which module pairs
# are in coincidence
Expand All @@ -185,7 +185,7 @@ def transform_BoxShape(
sgid = header.scanner.detection_efficiencies.module_pair_sgidlut[i, j]

if sgid >= 0:
print(i, j, sgid)
print(f"mod1 {i:03}, mod2 {j:03}, SGID {sgid:03}", end="\r")

start_det_el = det_element_center_list[i]
end_det_el = det_element_center_list[j]
Expand All @@ -200,8 +200,6 @@ def transform_BoxShape(
)
proj.tof_parameters = tof_params

# TODO: add TOF parameters

# get the module pair efficiencies - asumming that we only use 1 energy bin
module_pair_eff = (
header.scanner.detection_efficiencies.module_pair_efficiencies_vector[
Expand All @@ -212,61 +210,80 @@ def transform_BoxShape(
start_el_eff = xp.repeat(det_el_efficiencies[i], len(end_det_el), axis=0)
end_el_eff = xp.tile(det_el_efficiencies[j], (len(start_det_el)))

# TODO loop over TOF!
for tofbin in xp.arange(-(num_tofbins // 2), num_tofbins // 2 + 1):
# print(tofbin)
proj.event_tofbins = xp.full(
start_coords.shape[0], tofbin, dtype="int32"
)
sens_img += proj.adjoint(start_el_eff * end_el_eff * module_pair_eff)

print("")

# for some reason we have to divide the sens image by the number of TOF bins
# right now unclear why that is
sens_img /= num_tofbins
sens_img = res_model.adjoint(sens_img)

# %%
# read all coincidence events
print("Reading LM events")

if False:
num_prompts = 0
event_counter = 0
num_tof_bins = header.scanner.number_of_tof_bins()
num_prompts = 0
event_counter = 0
num_tof_bins = header.scanner.number_of_tof_bins()

xstart = []
xend = []
tof_bin = []
effs = []
xstart = []
xend = []
tof_bin = []
effs = []

for i_time_block, time_block in enumerate(reader.read_time_blocks()):
if isinstance(time_block, petsird.TimeBlock.EventTimeBlock):
num_prompts += len(time_block.value.prompt_events)
for i_time_block, time_block in enumerate(reader.read_time_blocks()):
if isinstance(time_block, petsird.TimeBlock.EventTimeBlock):
num_prompts += len(time_block.value.prompt_events)

for i_event, event in enumerate(time_block.value.prompt_events):
event_mods_and_els = get_module_and_element(
header.scanner.scanner_geometry, event.detector_ids
)
for i_event, event in enumerate(time_block.value.prompt_events):
event_mods_and_els = get_module_and_element(
header.scanner.scanner_geometry, event.detector_ids
)

event_start_coord = det_element_center_list[event_mods_and_els[0].module][
event_mods_and_els[0].el
]
xstart.append(event_start_coord)

event_start_coord = element_centers[
event_mods_and_els[0].module, event_mods_and_els[0].el
]
xstart.append(event_start_coord)
event_end_coord = det_element_center_list[event_mods_and_els[1].module][
event_mods_and_els[1].el
]
xend.append(event_end_coord)

event_end_coord = element_centers[
event_mods_and_els[1].module, event_mods_and_els[1].el
]
xend.append(event_end_coord)
# get the event efficiencies
effs.append(get_detection_efficiency(header.scanner, event))
# get the signed event TOF bin (0 is the central bin)
tof_bin.append(event.tof_idx - num_tof_bins // 2)

# get the event efficiencies
effs.append(get_detection_efficiency(header.scanner, event))
# get the signed event TOF bin (0 is the central bin)
tof_bin.append(event.tof_idx - num_tof_bins // 2)
event_counter += 1

reader.close()

xstart = xp.asarray(xstart, device=dev)
xend = xp.asarray(xend, device=dev)
effs = xp.asarray(effs, device=dev)
tof_bin = xp.asarray(tof_bin, device=dev)


# %%
# run a LM OSEM recon

event_counter += 1
num_iter = 50
recon = xp.ones(img_shape, dtype="float32")

reader.close()
proj = parallelproj.ListmodePETProjector(xstart, xend, img_shape, voxel_size)
proj.tof_parameters = tof_params
proj.event_tofbins = tof_bin

xstart = xp.asarray(xstart, device=dev)
xend = xp.asarray(xend, device=dev)
effs = xp.asarray(effs, device=dev)
tof_bin = xp.asarray(tof_bin, device=dev)
for i in range(num_iter):
print(f"it {(i +1):03} / {num_iter:03}")
lm_exp = effs * proj(res_model(recon))
print(lm_exp.min())
tmp = res_model(proj.adjoint(effs / lm_exp))
recon *= tmp / sens_img
9 changes: 6 additions & 3 deletions python/simulate_block_scanner_listmode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ def module_pair_eff_from_sgd(i_sgd: int) -> float:
det_el_efficiencies = 0.2 + 2 * xp.astype(
xp.random.rand(scanner.num_modules, lor_desc.num_lorendpoints_per_block), "float32"
)
# multiply the det el eff. of the first module by 3 to introduce more variation
det_el_efficiencies[0, :] *= 3

# divide the det el eff. of the last module by 3 to introduce more variation
det_el_efficiencies[-1, :] /= 3

# simulate a few dead crystals
det_el_efficiencies[det_el_efficiencies < 0.21] = 0

Expand Down Expand Up @@ -257,7 +263,6 @@ def module_pair_eff_from_sgd(i_sgd: int) -> float:
# and normalization are ignored)

ones_back_tof = fwd_op.adjoint(xp.ones(fwd_op.out_shape, dtype=xp.float32, device=dev))
xp.save("ones_back_tof.npy", ones_back_tof)
print(ones_back_tof.shape)

# %%
Expand Down Expand Up @@ -421,7 +426,6 @@ def module_pair_eff_from_sgd(i_sgd: int) -> float:
tofbin=event_tof_bin,
)

xp.save("lm_back.npy", lm_back)
lm_back_non_tof = parallelproj.joseph3d_back(
xstart=scanner.get_lor_endpoints(event_start_block, event_start_el),
xend=scanner.get_lor_endpoints(event_end_block, event_end_el),
Expand All @@ -430,7 +434,6 @@ def module_pair_eff_from_sgd(i_sgd: int) -> float:
voxsize=proj.voxel_size,
img_fwd=xp.ones(num_events, dtype=xp.float32, device=dev),
)
xp.save("lm_back_non_tof.npy", lm_back_non_tof)

vi = pv.ThreeAxisViewer([histo_back, lm_back, histo_back - lm_back])

Expand Down

0 comments on commit c2f703a

Please sign in to comment.