From 3008a339eeeaaafbbf3f91a0e2508fef8760a85f Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 11 Feb 2025 14:44:51 +0100 Subject: [PATCH] fix: when n_coils=1 no smaps computation is required. --- src/snake/core/handlers/fov.py | 39 +++++++++++++++++--------------- src/snake/core/phantom/static.py | 11 ++++++--- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/snake/core/handlers/fov.py b/src/snake/core/handlers/fov.py index c46be68..49e0264 100644 --- a/src/snake/core/handlers/fov.py +++ b/src/snake/core/handlers/fov.py @@ -131,13 +131,6 @@ def get_static(self, phantom: Phantom, sim_conf: SimConfig) -> Phantom: ), dtype=phantom.masks.dtype, ) - new_smaps = np.zeros( - ( - phantom.smaps.shape[0], - *tuple(round(size_vox[i] / zoom_factor[i]) for i in range(3)), - ), - dtype=phantom.smaps.dtype, - ) run_parallel( _apply_transform, @@ -149,17 +142,27 @@ def get_static(self, phantom: Phantom, sim_conf: SimConfig) -> Phantom: angles=self.angles, zoom_factor=zoom_factor, ) - - run_parallel( - _apply_transform, - phantom.smaps, - new_smaps, - parallel_axis=0, - center=center_vox, - size=size_vox, - angles=self.angles, - zoom_factor=zoom_factor, - ) + if phantom.smaps is not None: + return Phantom.from_masks(new_masks, sim_conf) + new_smaps = np.zeros( + ( + phantom.smaps.shape[0], + *tuple(round(size_vox[i] / zoom_factor[i]) for i in range(3)), + ), + dtype=phantom.smaps.dtype, + ) + run_parallel( + _apply_transform, + phantom.smaps, + new_smaps, + parallel_axis=0, + center=center_vox, + size=size_vox, + angles=self.angles, + zoom_factor=zoom_factor, + ) + else: + new_smaps = None # Create a new phantom with updated masks new_phantom = phantom.copy() diff --git a/src/snake/core/phantom/static.py b/src/snake/core/phantom/static.py index c8cf19d..ee3f9ac 100644 --- a/src/snake/core/phantom/static.py +++ b/src/snake/core/phantom/static.py @@ -151,14 +151,19 @@ def from_brainweb( ) tissues_mask = tissue_resized + smaps=None + if sim_conf.hardware.n_coils > 1: + smaps=get_smaps( + tissues_mask.shape[1:], + n_coils=sim_conf.hardware.n_coils, + ) + return cls( "brainweb", tissues_mask, labels=np.array([t[0] for t in tissues_list]), props=np.array([t[1:] for t in tissues_list]), - smaps=get_smaps( - tissues_mask.shape[1:], - n_coils=sim_conf.hardware.n_coils, + smaps=smaps, ), )