Skip to content

Commit

Permalink
fix: update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Feb 22, 2025
1 parent db124a1 commit c62ca80
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/snake/core/phantom/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def contrast(
sim_conf: SimConfig | None = None,
resample: bool = True,
aggregate: bool = True,
use_gpu: bool = True,
) -> NDArray[np.float32]:
"""Compute the contrast of the phantom for a given sequence.
Expand All @@ -443,7 +444,7 @@ def contrast(
raise ValueError("sim_conf must be provided for resampling.")
affine = sim_conf.fov.affine
shape = sim_conf.fov.shape
self = self.resample(affine, shape, use_gpu=True)
self = self.resample(affine, shape, use_gpu=use_gpu)

if sim_conf is not None:
TR = sim_conf.seq.TR_eff # Here we use the effective TR.
Expand Down
31 changes: 24 additions & 7 deletions src/snake/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,26 @@ def _validate_gpu_affine(use_gpu: bool = True) -> tuple[bool, Callable, ModuleTy
raise ImportError from exc

def affine_transform(
x: NDArray, *args: Any, output_shape: ThreeInts, **kwargs: Any
x: NDArray,
*args: Any,
output_shape: ThreeInts,
output: NDArray[np.float32] = None,
**kwargs: Any,
) -> NDArray:
output = xp.zeros(output_shape, dtype=x.dtype)
return cu_affine_transform(
output_gpu = xp.zeros(output_shape, dtype=x.dtype)
cu_affine_transform(
x,
*args,
output_shape=output_shape,
output=output,
output=output_gpu,
**kwargs,
texture_memory=x.dtype == xp.float32,
).get()
)
if output is not None:
xp.copyto(output, output_gpu)
return output
else:
return output_gpu.get()
except ImportError:
use_gpu = False
if not use_gpu:
Expand All @@ -63,6 +72,7 @@ def apply_affine(
old_affine: NDArray[np.float32],
new_affine: NDArray[np.float32],
new_shape: ThreeInts,
output: NDArray[np.float32] = None,
transform_affine: NDArray[np.float32] = None,
use_gpu: bool = True,
) -> NDArray[np.float32]:
Expand All @@ -81,6 +91,9 @@ def apply_affine(
transform_affine : NDArray, optional
Transformation affine, by default None
use_gpu : bool, optional
Try to use GPU, by default True
output: NDArray, optional
Output array, by default None
Returns
-------
Expand All @@ -92,7 +105,9 @@ def apply_affine(
transform_affine = effective_affine(new_affine, old_affine)
transform_affine = xp.asarray(transform_affine, dtype=xp.float32)
data = xp.asarray(data)
new_data = affine_transform(data, transform_affine, output_shape=new_shape)
new_data = affine_transform(
data, transform_affine, output_shape=new_shape, output=output
)

return new_data

Expand Down Expand Up @@ -141,7 +156,9 @@ def apply_affine4d(

if not use_gpu:
run_parallel(
apply_affine,
lambda x, out, *args, **kwargs: apply_affine(
x, *args, output=out, **kwargs
),
data,
new_array,
old_affine=old_affine,
Expand Down
7 changes: 5 additions & 2 deletions tests/test_phantom.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ def test_mrd(phantom: Phantom, tmpdir: Path):


@parametrize_with_cases("phantom", cases=CasesPhantom)
def test_contrast(phantom, sim_config):
@parametrize("use_gpu", [True, False])
def test_contrast(phantom, sim_config, use_gpu):
"""Test that the phantom can be used in a simulation."""
contrast = phantom.contrast(sim_conf=sim_config)
contrast = phantom.contrast(
sim_conf=sim_config,
)
# FIXME: This is not the correct way to test the contrast
assert contrast.shape == sim_config.fov.shape

0 comments on commit c62ca80

Please sign in to comment.