Skip to content

Commit

Permalink
Merge pull request #193 from mjo22/coordinate-system-rename
Browse files Browse the repository at this point in the history
Renaming that gives coordinate systems that are in pixels an “_in_pixels” ending
mjo22 authored Mar 25, 2024
2 parents 53258cf + bf50bfe commit 3b14a41
Showing 11 changed files with 118 additions and 89 deletions.
22 changes: 14 additions & 8 deletions src/cryojax/image/_edges.py
Original file line number Diff line number Diff line change
@@ -15,20 +15,23 @@
def crop_to_shape(
image_or_volume: Image,
shape: tuple[int, int],
) -> Inexact[Array, " *shape"]: ...
) -> Inexact[Array, " {shape[0]} {shape[1]}"]: ...


@overload
def crop_to_shape(
image_or_volume: Volume,
shape: tuple[int, int, int],
) -> Inexact[Array, " *shape"]: ...
) -> Inexact[Array, " {shape[0]} {shape[1]} {shape[2]}"]: ...


def crop_to_shape(
image_or_volume: Image | Volume,
shape: tuple[int, int] | tuple[int, int, int],
) -> Inexact[Array, " *shape"]:
) -> (
Inexact[Array, " {shape[0]} {shape[1]}"]
| Inexact[Array, " {shape[0]} {shape[1]} {shape[2]}"]
):
"""Crop an image or volume to a new shape around its
center.
"""
@@ -74,7 +77,7 @@ def crop_to_shape_with_center(
image: Image,
shape: tuple[int, int],
center: tuple[int, int],
) -> Inexact[Array, " *shape"]:
) -> Inexact[Array, "{shape[0]} {shape[1]}"]:
"""Crop an image to a new shape, given a center."""
if image.ndim != 2:
raise ValueError(
@@ -109,22 +112,25 @@ def pad_to_shape(
image_or_volume: Image,
shape: tuple[int, int],
**kwargs: Any,
) -> Inexact[Array, " *shape"]: ...
) -> Inexact[Array, " {shape[0]} {shape[1]}"]: ...


@overload
def pad_to_shape(
image_or_volume: Volume,
shape: tuple[int, int, int],
**kwargs: Any,
) -> Inexact[Array, " *shape"]: ...
) -> Inexact[Array, " {shape[0]} {shape[1]} {shape[2]}"]: ...


def pad_to_shape(
image_or_volume: Image | Volume,
shape: tuple[int, int] | tuple[int, int, int],
**kwargs: Any,
) -> Inexact[Array, " *shape"]:
) -> (
Inexact[Array, " {shape[0]} {shape[1]}"]
| Inexact[Array, " {shape[0]} {shape[1]} {shape[2]}"]
):
"""Pad an image or volume to a new shape."""
if image_or_volume.ndim not in [2, 3]:
raise ValueError(
@@ -165,7 +171,7 @@ def pad_to_shape(

def resize_with_crop_or_pad(
image: Image, shape: tuple[int, int], **kwargs
) -> Inexact[Array, " *shape"]:
) -> Inexact[Array, " {shape[0]} {shape[1]}"]:
"""Resize an image to a new shape using padding and cropping."""
if image.ndim != 2 or len(shape) != 2:
raise ValueError(
3 changes: 1 addition & 2 deletions src/cryojax/image/operators/_real_operator.py
Original file line number Diff line number Diff line change
@@ -54,8 +54,7 @@ def __call__( # pyright: ignore


class Gaussian2D(AbstractRealOperator, strict=True):
r"""
This operator represents a simple gaussian.
r"""This operator represents a simple gaussian.
Specifically, this is
$$g(r) = \frac{\kappa}{2\pi \beta} \exp(- (r - r_0)^2 / (2 \beta))$$
36 changes: 20 additions & 16 deletions src/cryojax/simulator/_config.py
Original file line number Diff line number Diff line change
@@ -44,17 +44,17 @@ class ImageConfig(Module, strict=True):
- `rescale_method`:
The interpolation method for pixel size rescaling. See
``jax.image.scale_and_translate`` for options.
- `wrapped_frequency_grid`:
- `wrapped_frequency_grid_in_pixels`:
The fourier wavevectors in the imaging plane, wrapped in
a `FrequencyGrid` object.
- `wrapped_padded_frequency_grid`:
- `wrapped_padded_frequency_grid_in_pixels`:
The fourier wavevectors in the imaging plane
in the padded coordinate system, wrapped in
a `FrequencyGrid` object.
- `wrapped_coordinate_grid`:
- `wrapped_coordinate_grid_in_pixels`:
The coordinates in the imaging plane, wrapped
in a `CoordinateGrid` object.
- `wrapped_padded_coordinate_grid`:
- `wrapped_padded_coordinate_grid_in_pixels`:
The coordinates in the imaging plane
in the padded coordinate system, wrapped in a
`CoordinateGrid` object.
@@ -67,10 +67,10 @@ class ImageConfig(Module, strict=True):
pad_mode: Union[str, Callable] = field(static=True)
rescale_method: str = field(static=True)

wrapped_frequency_grid: FrequencyGrid
wrapped_padded_frequency_grid: FrequencyGrid
wrapped_coordinate_grid: CoordinateGrid
wrapped_padded_coordinate_grid: CoordinateGrid
wrapped_frequency_grid_in_pixels: FrequencyGrid
wrapped_padded_frequency_grid_in_pixels: FrequencyGrid
wrapped_coordinate_grid_in_pixels: CoordinateGrid
wrapped_padded_coordinate_grid_in_pixels: CoordinateGrid

def __init__(
self,
@@ -99,10 +99,14 @@ def __init__(
else:
self.padded_shape = padded_shape
# Set coordinates
self.wrapped_frequency_grid = FrequencyGrid(shape=self.shape)
self.wrapped_padded_frequency_grid = FrequencyGrid(shape=self.padded_shape)
self.wrapped_coordinate_grid = CoordinateGrid(shape=self.shape)
self.wrapped_padded_coordinate_grid = CoordinateGrid(shape=self.padded_shape)
self.wrapped_frequency_grid_in_pixels = FrequencyGrid(shape=self.shape)
self.wrapped_padded_frequency_grid_in_pixels = FrequencyGrid(
shape=self.padded_shape
)
self.wrapped_coordinate_grid_in_pixels = CoordinateGrid(shape=self.shape)
self.wrapped_padded_coordinate_grid_in_pixels = CoordinateGrid(
shape=self.padded_shape
)

def __check_init__(self):
if self.padded_shape[0] < self.shape[0] or self.padded_shape[1] < self.shape[1]:
@@ -113,19 +117,19 @@ def __check_init__(self):

@cached_property
def wrapped_coordinate_grid_in_angstroms(self) -> CoordinateGrid:
return self.pixel_size * self.wrapped_coordinate_grid # type: ignore
return self.pixel_size * self.wrapped_coordinate_grid_in_pixels # type: ignore

@cached_property
def wrapped_frequency_grid_in_angstroms(self) -> FrequencyGrid:
return self.wrapped_frequency_grid / self.pixel_size
return self.wrapped_frequency_grid_in_pixels / self.pixel_size

@cached_property
def wrapped_padded_coordinate_grid_in_angstroms(self) -> CoordinateGrid:
return self.pixel_size * self.wrapped_padded_coordinate_grid # type: ignore
return self.pixel_size * self.wrapped_padded_coordinate_grid_in_pixels # type: ignore

@cached_property
def wrapped_padded_frequency_grid_in_angstroms(self) -> FrequencyGrid:
return self.wrapped_padded_frequency_grid / self.pixel_size
return self.wrapped_padded_frequency_grid_in_pixels / self.pixel_size

def rescale_to_pixel_size(
self,
22 changes: 14 additions & 8 deletions src/cryojax/simulator/_detector.py
Original file line number Diff line number Diff line change
@@ -28,14 +28,18 @@ class AbstractDQE(AbstractFourierOperator, strict=True):
@abstractmethod
def __call__(
self,
frequency_grid_maybe_in_angstroms: ImageCoords,
frequency_grid_in_angstroms_or_pixels: ImageCoords,
*,
pixel_size: Optional[RealNumber] = None,
) -> RealImage | RealNumber:
"""**Arguments:**
- `frequency_grid_maybe_in_angstroms`: A frequency grid given in units of
nyquist.
- `frequency_grid_in_angstroms_or_pixels`: A frequency grid
given in angstroms
or pixels. If given
in angstroms, `pixel_size`
must be passed
- `pixel_size`: The pixel size of `frequency_grid_in_angstroms_or_pixels`.
"""
raise NotImplementedError

@@ -51,7 +55,7 @@ def __init__(self):
@override
def __call__(
self,
frequency_grid_maybe_in_angstroms: ImageCoords,
frequency_grid_in_angstroms_or_pixels: ImageCoords,
*,
pixel_size: Optional[RealNumber] = None,
) -> RealNumber:
@@ -72,15 +76,17 @@ class IdealDQE(AbstractDQE, strict=True):
@override
def __call__(
self,
frequency_grid_maybe_in_angstroms: ImageCoords,
frequency_grid_in_angstroms_or_pixels: ImageCoords,
*,
pixel_size: Optional[RealNumber] = None,
) -> RealImage:
if pixel_size is None:
frequency_grid_in_nyquist_units = frequency_grid_maybe_in_angstroms / 0.5
frequency_grid_in_nyquist_units = (
frequency_grid_in_angstroms_or_pixels / 0.5
)
else:
frequency_grid_in_nyquist_units = (
frequency_grid_maybe_in_angstroms * pixel_size
frequency_grid_in_angstroms_or_pixels * pixel_size
) / 0.5
return (
self.fraction_detected_electrons**2
@@ -115,7 +121,7 @@ def __call__(
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Pass the image through the detector model."""
N_pix = np.prod(config.padded_shape)
frequency_grid = config.wrapped_padded_frequency_grid.get()
frequency_grid = config.wrapped_padded_frequency_grid_in_pixels.get()
# Compute the time-integrated electron flux in pixels
electrons_per_pixel = dose.electrons_per_angstrom_squared * config.pixel_size**2
# ... now the total number of electrons over the entire image
Original file line number Diff line number Diff line change
@@ -55,7 +55,7 @@ def __call__(
"""Compute a projection of the real-space potential by extracting
a central slice in fourier space.
"""
frequency_slice = potential.wrapped_frequency_slice.get()
frequency_slice = potential.wrapped_frequency_slice_in_pixels.get()
N = frequency_slice.shape[1]
if potential.shape != (N, N, N):
raise AttributeError(
6 changes: 4 additions & 2 deletions src/cryojax/simulator/_integrators/_nufft_project.py
Original file line number Diff line number Diff line change
@@ -41,14 +41,16 @@ def __call__(
shape = potential.shape
fourier_projection = project_with_nufft(
potential.real_voxel_grid.ravel(),
potential.wrapped_coordinate_grid.get().reshape((math.prod(shape), 3)),
potential.wrapped_coordinate_grid_in_pixels.get().reshape(
(math.prod(shape), 3)
),
config.padded_shape,
eps=self.eps,
)
elif isinstance(potential, RealVoxelCloudPotential):
fourier_projection = project_with_nufft(
potential.voxel_weights,
potential.wrapped_coordinate_list.get(),
potential.wrapped_coordinate_list_in_pixels.get(),
config.padded_shape,
eps=self.eps,
)
2 changes: 1 addition & 1 deletion src/cryojax/simulator/_pipeline.py
Original file line number Diff line number Diff line change
@@ -120,7 +120,7 @@ def crop_and_apply_operators(
if (
self.filter is not None
and self.filter.buffer.shape
== config.wrapped_padded_frequency_grid.get().shape[0:2]
== config.wrapped_padded_frequency_grid_in_pixels.get().shape[0:2]
):
# ... apply the filter here if it is the same size as the padded
# coordinates
90 changes: 48 additions & 42 deletions src/cryojax/simulator/_potential/_voxel_potential.py
Original file line number Diff line number Diff line change
@@ -87,27 +87,29 @@ class AbstractFourierVoxelGridPotential(AbstractVoxelPotential, strict=True):
in fourier-space.
"""

wrapped_frequency_slice: AbstractVar[FrequencySlice]
wrapped_frequency_slice_in_pixels: AbstractVar[FrequencySlice]

@abstractmethod
def __init__(
self,
fourier_voxel_grid: Shaped[ComplexCubicVolume, "..."],
wrapped_frequency_slice: FrequencySlice,
wrapped_frequency_slice_in_pixels: FrequencySlice,
voxel_size: Shaped[RealNumber, "..."] | float,
):
raise NotImplementedError

@cached_property
def wrapped_frequency_slice_in_angstroms(self) -> FrequencySlice:
"""The `wrapped_frequency_slice` in angstroms."""
return self.wrapped_frequency_slice / self.voxel_size
return self.wrapped_frequency_slice_in_pixels / self.voxel_size

def rotate_to_pose(self, pose: AbstractPose) -> Self:
return eqx.tree_at(
lambda d: d.wrapped_frequency_slice.array,
lambda d: d.wrapped_frequency_slice_in_pixels.array,
self,
pose.rotate_coordinates(self.wrapped_frequency_slice.get(), inverse=True),
pose.rotate_coordinates(
self.wrapped_frequency_slice_in_pixels.get(), inverse=True
),
)

@classmethod
@@ -213,7 +215,7 @@ class FourierVoxelGridPotential(AbstractFourierVoxelGridPotential):
"""A 3D scattering potential voxel grid in fourier-space."""

fourier_voxel_grid: Shaped[ComplexCubicVolume, "..."]
wrapped_frequency_slice: FrequencySlice
wrapped_frequency_slice_in_pixels: FrequencySlice
voxel_size: Shaped[RealNumber, "..."] = field(converter=error_if_not_positive)

is_real: ClassVar[bool] = False
@@ -222,18 +224,18 @@ class FourierVoxelGridPotential(AbstractFourierVoxelGridPotential):
def __init__(
self,
fourier_voxel_grid: Shaped[ComplexCubicVolume, "..."],
wrapped_frequency_slice: FrequencySlice,
wrapped_frequency_slice_in_pixels: FrequencySlice,
voxel_size: Shaped[RealNumber, "..."] | float,
):
"""**Arguments:**
- `fourier_voxel_grid`: The cubic voxel grid in fourier space.
- `wrapped_frequency_slice`: Frequency slice coordinate system,
wrapped in a `FrequencySlice` object.
- `wrapped_frequency_slice_in_pixels`: Frequency slice coordinate system,
wrapped in a `FrequencySlice` object.
- `voxel_size`: The voxel size.
"""
self.fourier_voxel_grid = jnp.asarray(fourier_voxel_grid)
self.wrapped_frequency_slice = wrapped_frequency_slice
self.wrapped_frequency_slice_in_pixels = wrapped_frequency_slice_in_pixels
self.voxel_size = jnp.asarray(voxel_size)

@property
@@ -247,15 +249,15 @@ class FourierVoxelGridPotentialInterpolator(AbstractFourierVoxelGridPotential):
"""

coefficients: Shaped[ComplexCubicVolume, "..."]
wrapped_frequency_slice: FrequencySlice
wrapped_frequency_slice_in_pixels: FrequencySlice
voxel_size: Shaped[RealNumber, "..."] = field(converter=error_if_not_positive)

is_real: ClassVar[bool] = False

def __init__(
self,
fourier_voxel_grid: Shaped[ComplexCubicVolume, "..."],
wrapped_frequency_slice: FrequencySlice,
wrapped_frequency_slice_in_pixels: FrequencySlice,
voxel_size: Shaped[RealNumber, "..."] | float,
):
"""
@@ -275,8 +277,8 @@ def __init__(
**Arguments:**
- `fourier_voxel_grid`: The cubic voxel grid in fourier space.
- `wrapped_frequency_slice`: Frequency slice coordinate system,
wrapped in a `FrequencySlice` object.
- `wrapped_frequency_slice_in_pixels`: Frequency slice coordinate system,
wrapped in a `FrequencySlice` object.
- `voxel_size`: The voxel size.
"""
n_batch_dims = fourier_voxel_grid.ndim - 3
@@ -286,7 +288,7 @@ def __init__(
self.coefficients = compute_spline_coefficients_3d(
jnp.asarray(fourier_voxel_grid)
)
self.wrapped_frequency_slice = wrapped_frequency_slice
self.wrapped_frequency_slice_in_pixels = wrapped_frequency_slice_in_pixels
self.voxel_size = jnp.asarray(voxel_size)

@property
@@ -300,26 +302,26 @@ class RealVoxelGridPotential(AbstractVoxelPotential, strict=True):
"""Abstraction of a 3D scattering potential voxel grid in real-space."""

real_voxel_grid: Shaped[RealCubicVolume, "..."]
wrapped_coordinate_grid: CoordinateGrid
wrapped_coordinate_grid_in_pixels: CoordinateGrid
voxel_size: Shaped[RealNumber, "..."] = field(converter=error_if_not_positive)

is_real: ClassVar[bool] = True

def __init__(
self,
real_voxel_grid: Shaped[RealCubicVolume, "..."],
wrapped_coordinate_grid: CoordinateGrid,
wrapped_coordinate_grid_in_pixels: CoordinateGrid,
voxel_size: Shaped[RealNumber, "..."] | float,
):
"""**Arguments:**
- `real_voxel_grid`: The voxel grid in fourier space.
- `wrapped_coordinate_grid`: A coordinate grid, wrapped into a
`CoordinateGrid` object.
- `wrapped_coordinate_grid_in_pixels`: A coordinate grid, wrapped into a
`CoordinateGrid` object.
- `voxel_size`: The voxel size.
"""
self.real_voxel_grid = jnp.asarray(real_voxel_grid)
self.wrapped_coordinate_grid = wrapped_coordinate_grid
self.wrapped_coordinate_grid_in_pixels = wrapped_coordinate_grid_in_pixels
self.voxel_size = jnp.asarray(voxel_size)

@property
@@ -329,13 +331,15 @@ def shape(self) -> tuple[int, int, int]:
@cached_property
def wrapped_coordinate_grid_in_angstroms(self) -> CoordinateGrid:
"""The `coordinate_grid` in angstroms."""
return self.voxel_size * self.wrapped_coordinate_grid # type: ignore
return self.voxel_size * self.wrapped_coordinate_grid_in_pixels # type: ignore

def rotate_to_pose(self, pose: AbstractPose) -> Self:
return eqx.tree_at(
lambda d: d.wrapped_coordinate_grid.array,
lambda d: d.wrapped_coordinate_grid_in_pixels.array,
self,
pose.rotate_coordinates(self.wrapped_coordinate_grid.get(), inverse=False),
pose.rotate_coordinates(
self.wrapped_coordinate_grid_in_pixels.get(), inverse=False
),
)

@overload
@@ -364,7 +368,7 @@ def from_real_voxel_grid(
real_voxel_grid: Float[Array, "N N N"] | Float[np.ndarray, "N N N"],
voxel_size: Float[Array, ""] | Float[np.ndarray, ""] | float,
*,
coordinate_grid: Optional[CoordinateGrid] = None,
coordinate_grid_in_pixels: Optional[CoordinateGrid] = None,
crop_scale: Optional[float] = None,
) -> Self:
"""Load a `RealVoxelGridPotential` from a real-valued 3D electron
@@ -383,7 +387,7 @@ def from_real_voxel_grid(
jnp.asarray(voxel_size),
)
# Make coordinates if not given
if coordinate_grid is None:
if coordinate_grid_in_pixels is None:
# Option for cropping template
if crop_scale is not None:
if crop_scale > 1.0:
@@ -393,9 +397,9 @@ def from_real_voxel_grid(
tuple([int(s * crop_scale) for s in real_voxel_grid.shape[-3:]]),
)
real_voxel_grid = crop_to_shape(real_voxel_grid, cropped_shape)
coordinate_grid = CoordinateGrid(real_voxel_grid.shape[-3:])
coordinate_grid_in_pixels = CoordinateGrid(real_voxel_grid.shape[-3:])

return cls(real_voxel_grid, coordinate_grid, voxel_size)
return cls(real_voxel_grid, coordinate_grid_in_pixels, voxel_size)

@classmethod
def from_atoms(
@@ -433,7 +437,7 @@ def from_atoms(
return cls.from_real_voxel_grid(
real_voxel_grid,
voxel_size,
coordinate_grid=coordinate_grid_in_angstroms / voxel_size,
coordinate_grid_in_pixels=coordinate_grid_in_angstroms / voxel_size,
**kwargs,
)

@@ -450,42 +454,44 @@ class RealVoxelCloudPotential(AbstractVoxelPotential, strict=True):
"""

voxel_weights: Shaped[RealPointCloud, "..."]
wrapped_coordinate_list: CoordinateList
wrapped_coordinate_list_in_pixels: CoordinateList
voxel_size: Shaped[RealNumber, "..."] = field(converter=error_if_not_positive)

is_real: ClassVar[bool] = True

def __init__(
self,
voxel_weights: Shaped[RealPointCloud, "..."],
wrapped_coordinate_list: CoordinateList,
wrapped_coordinate_list_in_pixels: CoordinateList,
voxel_size: Shaped[RealNumber, "..."] | float,
):
"""**Arguments:**
- `voxel_weights`: A point-cloud of voxel scattering potential values.
- `wrapped_coordinate_list`: Coordinate list for the `voxel_weights`, wrapped
in a `CoordinateList` object.
- `wrapped_coordinate_list_in_pixels`: Coordinate list for the `voxel_weights`,
wrapped in a `CoordinateList` object.
- `voxel_size`: The voxel size.
"""
self.voxel_weights = jnp.asarray(voxel_weights)
self.wrapped_coordinate_list = wrapped_coordinate_list
self.wrapped_coordinate_list_in_pixels = wrapped_coordinate_list_in_pixels
self.voxel_size = jnp.asarray(voxel_size)

@property
def shape(self) -> tuple[int, int]:
return cast(tuple[int, int], self.voxel_weights.shape)

@cached_property
def coordinate_list_in_angstroms(self) -> CoordinateList:
def wrapped_coordinate_list_in_angstroms(self) -> CoordinateList:
"""The `coordinate_list` in angstroms."""
return self.voxel_size * self.wrapped_coordinate_list # type: ignore
return self.voxel_size * self.wrapped_coordinate_list_in_pixels # type: ignore

def rotate_to_pose(self, pose: AbstractPose) -> Self:
return eqx.tree_at(
lambda d: d.wrapped_coordinate_list.array,
lambda d: d.wrapped_coordinate_list_in_pixels.array,
self,
pose.rotate_coordinates(self.wrapped_coordinate_list.get(), inverse=False),
pose.rotate_coordinates(
self.wrapped_coordinate_list_in_pixels.get(), inverse=False
),
)

@classmethod
@@ -494,7 +500,7 @@ def from_real_voxel_grid(
real_voxel_grid: Float[Array, "N N N"] | Float[np.ndarray, "N N N"],
voxel_size: Float[Array, ""] | Float[np.ndarray, ""] | float,
*,
coordinate_grid: Optional[CoordinateGrid] = None,
coordinate_grid_in_pixels: Optional[CoordinateGrid] = None,
rtol: float = 1e-05,
atol: float = 1e-08,
) -> Self:
@@ -516,13 +522,13 @@ def from_real_voxel_grid(
jnp.asarray(voxel_size),
)
# Make coordinates if not given
if coordinate_grid is None:
coordinate_grid = CoordinateGrid(real_voxel_grid.shape)
if coordinate_grid_in_pixels is None:
coordinate_grid_in_pixels = CoordinateGrid(real_voxel_grid.shape)
# ... mask zeros to store smaller arrays. This
# option is not jittable.
nonzero = jnp.where(~jnp.isclose(real_voxel_grid, 0.0, rtol=rtol, atol=atol))
flat_potential = real_voxel_grid[nonzero]
coordinate_list = CoordinateList(coordinate_grid.get()[nonzero])
coordinate_list = CoordinateList(coordinate_grid_in_pixels.get()[nonzero])

return cls(flat_potential, coordinate_list, voxel_size)

@@ -562,7 +568,7 @@ def from_atoms(
return cls.from_real_voxel_grid(
real_voxel_grid,
voxel_size,
coordinate_grid=coordinate_grid_in_angstroms / voxel_size,
coordinate_grid_in_pixels=coordinate_grid_in_angstroms / voxel_size,
**kwargs,
)

2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -92,7 +92,7 @@ def potential(sample_mrc_path):

@pytest.fixture
def filters(config):
return op.LowpassFilter(config.wrapped_padded_frequency_grid.get())
return op.LowpassFilter(config.wrapped_padded_frequency_grid_in_pixels.get())


@pytest.fixture
18 changes: 12 additions & 6 deletions tests/test_potential.py
Original file line number Diff line number Diff line change
@@ -30,14 +30,20 @@ def test_voxel_electron_potential_loaders():
for potential in [real_potential, fourier_potential, cloud_potential]:
assert potential.voxel_size == jnp.asarray(voxel_size)

assert isinstance(fourier_potential.wrapped_frequency_slice, FrequencySlice)
assert isinstance(
fourier_potential.wrapped_frequency_slice.get(), VolumeSliceCoords
fourier_potential.wrapped_frequency_slice_in_pixels, FrequencySlice
)
assert isinstance(
fourier_potential.wrapped_frequency_slice_in_pixels.get(), VolumeSliceCoords
)
assert isinstance(real_potential.wrapped_coordinate_grid_in_pixels, CoordinateGrid)
assert isinstance(
real_potential.wrapped_coordinate_grid_in_pixels.get(), VolumeCoords
)
assert isinstance(cloud_potential.wrapped_coordinate_list_in_pixels, CoordinateList)
assert isinstance(
cloud_potential.wrapped_coordinate_list_in_pixels.get(), PointCloudCoords3D
)
assert isinstance(real_potential.wrapped_coordinate_grid, CoordinateGrid)
assert isinstance(real_potential.wrapped_coordinate_grid.get(), VolumeCoords)
assert isinstance(cloud_potential.wrapped_coordinate_list, CoordinateList)
assert isinstance(cloud_potential.wrapped_coordinate_list.get(), PointCloudCoords3D)


def test_electron_potential_vmap(potential, integrator, config):
4 changes: 2 additions & 2 deletions tests/test_shape.py
Original file line number Diff line number Diff line change
@@ -17,8 +17,8 @@ def test_fourier_shape(model, request):
model = request.getfixturevalue(model)
image = model.render(get_real=False)
padded_image = model.render(view_cropped=False, get_real=False)
assert image.shape == model.config.wrapped_frequency_grid.get().shape[0:2]
assert image.shape == model.config.wrapped_frequency_grid_in_pixels.get().shape[0:2]
assert (
padded_image.shape
== model.config.wrapped_padded_frequency_grid.get().shape[0:2]
== model.config.wrapped_padded_frequency_grid_in_pixels.get().shape[0:2]
)

0 comments on commit 3b14a41

Please sign in to comment.