Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A couple of enhancements for MPRAGE #3

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions src/torchsim/models/mpnrage.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def set_sequence(
nshots: int,
flip: float,
TR: float,
MPRAGE_TR: float | None = None,
num_inversions: int = 1,
TI: float = 0.0,
slice_prof: float | npt.ArrayLike = 1.0,
):
Expand All @@ -96,6 +98,10 @@ def set_sequence(
TI : float, optional
Inversion time in milliseconds.
The default is ``0.0``.
MPRAGE_TR : float default is None
Repetition time in milliseconds for the whole inversion block.
num_inversions : int, optional
Number of inversion pulses, default is ``1``.
slice_prof : float | npt.ArrayLike, optional
Flip angle scaling along slice profile.
The default is ``1.0``.
Expand All @@ -106,6 +112,8 @@ def set_sequence(
self.sequence.TR = TR * 1e-3 # ms -> s
self.sequence.TI = TI * 1e-3 # ms -> s
self.sequence.slice_prof = slice_prof
self.sequence.num_inversions = num_inversions
self.sequence.MPRAGE_TR = MPRAGE_TR * 1e-3 if MPRAGE_TR is not None else None

@staticmethod
def _engine(
Expand All @@ -118,6 +126,8 @@ def _engine(
B1: float | npt.ArrayLike = 1.0,
inv_efficiency: float | npt.ArrayLike = 1.0,
slice_prof: float | npt.ArrayLike = 1.0,
num_inversions: int = 1,
MPRAGE_TR: float = None,
):
# Prepare relaxation parameters
R1 = 1e3 / T1
Expand All @@ -136,26 +146,29 @@ def _engine(

# Prepare relaxation operator for sequence loop
E1, rE1 = epg.longitudinal_relaxation_op(R1, TR)

if MPRAGE_TR is not None:
mprageE1, mpragerE1 = epg.longitudinal_relaxation_op(R1, MPRAGE_TR)
# Initialize signal
signal = []
for i in range(num_inversions):
# Apply inversion
states = epg.adiabatic_inversion(states, inv_efficiency)
states = epg.longitudinal_relaxation(states, E1inv, rE1inv)
states = epg.spoil(states)

# Apply inversion
states = epg.adiabatic_inversion(states, inv_efficiency)
states = epg.longitudinal_relaxation(states, E1inv, rE1inv)
states = epg.spoil(states)

# Scan loop
for p in range(nshots):
# Scan loop
for p in range(nshots):

# Apply RF pulse
states = epg.rf_pulse(states, RF)
# Apply RF pulse
states = epg.rf_pulse(states, RF)

# Record signal
signal.append(epg.get_signal(states))
# Record signal
signal.append(epg.get_signal(states))

# Evolve
states = epg.longitudinal_relaxation(states, E1, rE1)
states = epg.spoil(states)
# Evolve
states = epg.longitudinal_relaxation(states, E1, rE1)
states = epg.spoil(states)
if MPRAGE_TR is not None:
epg.longitudinal_relaxation(states, mprageE1, mpragerE1)

return M0 * 1j * torch.stack(signal)
44 changes: 29 additions & 15 deletions src/torchsim/models/mprage.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def set_sequence(
flip: float,
TRspgr: float,
nshots: int | npt.ArrayLike,
TRmprage: float = None,
num_inversions: int = 1,
):
"""
Set sequence parameters for the SPGR model.
Expand All @@ -88,13 +90,15 @@ def set_sequence(
Flip angle train in degrees.
TRspgr : float
Repetition time in milliseconds for each SPGR readout.
TRmprage : float
TRmprage : float default is None
Repetition time in milliseconds for the whole inversion block.
nshots : int | npt.ArrayLike
Number of SPGR readout within the inversion block of shape ``(npre, npost)``
If scalar, assume ``npre == npost == 0.5 * nshots``. Usually, this
is the number of slice encoding lines ``(nshots = nz / Rz)``,
i.e., the number of slices divided by the total acceleration factor along ``z``.
num_inversions : int, optional
Number of inversion pulses, default is ``1``.

"""
self.sequence.nshots = nshots
Expand All @@ -104,6 +108,10 @@ def set_sequence(
if nshots.numel() == 1:
nshots = torch.repeat_interleave(nshots // 2, 2)
self.sequence.nshots = nshots
self.sequence.TRmprage = TRmprage * 1e-3 # ms -> s
if TRmprage is None and num_inversions > 1:
raise ValueError("TRmprage must be provided for multiple inversions")
self.sequence.num_inversions = num_inversions

@staticmethod
def _engine(
Expand All @@ -115,6 +123,7 @@ def _engine(
nshots: int | npt.ArrayLike,
M0: float | npt.ArrayLike = 1.0,
inv_efficiency: float | npt.ArrayLike = 1.0,
num_inversions: int = 1,
):
R1 = 1e3 / T1

Expand All @@ -135,21 +144,26 @@ def _engine(

# Prepare relaxation operator for sequence loop
E1, rE1 = epg.longitudinal_relaxation_op(R1, TRspgr)
if TRmprage is not None:
mprageE1, mpragerE1 = epg.longitudinal_relaxation_op(R1, TRmprage)

# Apply inversion
states = epg.adiabatic_inversion(states, inv_efficiency)
states = epg.longitudinal_relaxation(states, E1inv, rE1inv)
states = epg.spoil(states)

signal = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to store magnetization for each spgr shot (e.g., to study signal modulation in k-space or for contrast-resolved reconstruction), it is better to use torchsim.models.MPnRAGEModel.

# Scan loop
for p in range(nshots_bef):

# Apply RF pulse
states = epg.rf_pulse(states, RF)

# Evolve
states = epg.longitudinal_relaxation(states, E1, rE1)
for i in range(num_inversions):
# Apply inversion
states = epg.adiabatic_inversion(states, inv_efficiency)
states = epg.longitudinal_relaxation(states, E1inv, rE1inv)
states = epg.spoil(states)


for p in range(nshots_bef*2):

# Apply RF pulse
states = epg.rf_pulse(states, RF)
# Evolve
states = epg.longitudinal_relaxation(states, E1, rE1)
signal.append(M0 * 1j * epg.get_signal(states))
states = epg.spoil(states)
if TRmprage is not None:
epg.longitudinal_relaxation(states, mprageE1, mpragerE1)
# Record signal
return M0 * 1j * epg.get_signal(states)
return torch.stack(signal)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move return M0 * 1j * torch.stack(signal) here?

Loading