From c62ca80c67eb6b80077ad6d088b974d29004d35a Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Sat, 22 Feb 2025 20:17:36 +0100 Subject: [PATCH] fix: update tests --- src/snake/core/phantom/static.py | 3 ++- src/snake/core/transform.py | 31 ++++++++++++++++++++++++------- tests/test_phantom.py | 7 +++++-- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/snake/core/phantom/static.py b/src/snake/core/phantom/static.py index 1eb65ab..13d2cef 100644 --- a/src/snake/core/phantom/static.py +++ b/src/snake/core/phantom/static.py @@ -418,6 +418,7 @@ def contrast( sim_conf: SimConfig | None = None, resample: bool = True, aggregate: bool = True, + use_gpu: bool = True, ) -> NDArray[np.float32]: """Compute the contrast of the phantom for a given sequence. @@ -443,7 +444,7 @@ def contrast( raise ValueError("sim_conf must be provided for resampling.") affine = sim_conf.fov.affine shape = sim_conf.fov.shape - self = self.resample(affine, shape, use_gpu=True) + self = self.resample(affine, shape, use_gpu=use_gpu) if sim_conf is not None: TR = sim_conf.seq.TR_eff # Here we use the effective TR. diff --git a/src/snake/core/transform.py b/src/snake/core/transform.py index 032585c..97fad46 100644 --- a/src/snake/core/transform.py +++ b/src/snake/core/transform.py @@ -37,17 +37,26 @@ def _validate_gpu_affine(use_gpu: bool = True) -> tuple[bool, Callable, ModuleTy raise ImportError from exc def affine_transform( - x: NDArray, *args: Any, output_shape: ThreeInts, **kwargs: Any + x: NDArray, + *args: Any, + output_shape: ThreeInts, + output: NDArray[np.float32] = None, + **kwargs: Any, ) -> NDArray: - output = xp.zeros(output_shape, dtype=x.dtype) - return cu_affine_transform( + output_gpu = xp.zeros(output_shape, dtype=x.dtype) + cu_affine_transform( x, *args, output_shape=output_shape, - output=output, + output=output_gpu, **kwargs, texture_memory=x.dtype == xp.float32, - ).get() + ) + if output is not None: + xp.copyto(output, output_gpu) + return output + else: + return output_gpu.get() except ImportError: use_gpu = False if not use_gpu: @@ -63,6 +72,7 @@ def apply_affine( old_affine: NDArray[np.float32], new_affine: NDArray[np.float32], new_shape: ThreeInts, + output: NDArray[np.float32] = None, transform_affine: NDArray[np.float32] = None, use_gpu: bool = True, ) -> NDArray[np.float32]: @@ -81,6 +91,9 @@ def apply_affine( transform_affine : NDArray, optional Transformation affine, by default None use_gpu : bool, optional + Try to use GPU, by default True + output: NDArray, optional + Output array, by default None Returns ------- @@ -92,7 +105,9 @@ def apply_affine( transform_affine = effective_affine(new_affine, old_affine) transform_affine = xp.asarray(transform_affine, dtype=xp.float32) data = xp.asarray(data) - new_data = affine_transform(data, transform_affine, output_shape=new_shape) + new_data = affine_transform( + data, transform_affine, output_shape=new_shape, output=output + ) return new_data @@ -141,7 +156,9 @@ def apply_affine4d( if not use_gpu: run_parallel( - apply_affine, + lambda x, out, *args, **kwargs: apply_affine( + x, *args, output=out, **kwargs + ), data, new_array, old_affine=old_affine, diff --git a/tests/test_phantom.py b/tests/test_phantom.py index df15e98..0ae53c5 100644 --- a/tests/test_phantom.py +++ b/tests/test_phantom.py @@ -71,8 +71,11 @@ def test_mrd(phantom: Phantom, tmpdir: Path): @parametrize_with_cases("phantom", cases=CasesPhantom) -def test_contrast(phantom, sim_config): +@parametrize("use_gpu", [True, False]) +def test_contrast(phantom, sim_config, use_gpu): """Test that the phantom can be used in a simulation.""" - contrast = phantom.contrast(sim_conf=sim_config) + contrast = phantom.contrast( + sim_conf=sim_config, + ) # FIXME: This is not the correct way to test the contrast assert contrast.shape == sim_config.fov.shape