Skip to content

Commit

Permalink
Merge pull request #258 from mjo22/multislice-updated
Browse files Browse the repository at this point in the history
First pass at multislice implementation
  • Loading branch information
mjo22 authored Dec 6, 2024
2 parents b4d678d + 389cc4a commit 8de2685
Show file tree
Hide file tree
Showing 15 changed files with 995 additions and 47 deletions.
265 changes: 265 additions & 0 deletions docs/examples/multislice.ipynb

Large diffs are not rendered by default.

16 changes: 7 additions & 9 deletions docs/examples/simulate-image.ipynb

Large diffs are not rendered by default.

18 changes: 16 additions & 2 deletions src/cryojax/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
from ..simulator._potential_integrator import (
EwaldSphereExtraction as EwaldSphereExtraction,
from ..simulator._multislice_integrator import (
AbstractMultisliceIntegrator as AbstractMultisliceIntegrator,
FFTMultisliceIntegrator as FFTMultisliceIntegrator,
)

# from ..simulator._potential_integrator import (
# EwaldSphereExtraction as EwaldSphereExtraction,
# )
from ..simulator._scattering_theory import (
AbstractWaveScatteringTheory as AbstractWaveScatteringTheory,
HighEnergyScatteringTheory as HighEnergyScatteringTheory,
MultisliceScatteringTheory as MultisliceScatteringTheory,
)
from ..simulator._transfer_theory import (
WaveTransferFunction as WaveTransferFunction,
WaveTransferTheory as WaveTransferTheory,
)
8 changes: 4 additions & 4 deletions src/cryojax/image/_rescale_pixel_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax
import jax.numpy as jnp
from jax.image import scale_and_translate
from jaxtyping import Array, Complex, Float
from jaxtyping import Array, Complex, Float, Inexact

from ._fft import irfftn, rfftn

Expand All @@ -17,7 +17,7 @@ def rescale_pixel_size(
current_pixel_size: Float[Array, ""],
new_pixel_size: Float[Array, ""],
method: str = "bicubic",
antialias: bool = False,
antialias: bool = True,
) -> Float[Array, "y_dim x_dim"]:
"""
Measure an image at a given pixel size using interpolation.
Expand Down Expand Up @@ -71,7 +71,7 @@ def rescale_pixel_size(

def maybe_rescale_pixel_size(
real_or_fourier_image: (
Float[Array, "padded_y_dim padded_x_dim"]
Inexact[Array, "padded_y_dim padded_x_dim"]
| Complex[Array, "padded_y_dim padded_x_dim//2+1"]
),
current_pixel_size: Float[Array, ""],
Expand All @@ -80,7 +80,7 @@ def maybe_rescale_pixel_size(
shape_in_real_space: Optional[tuple[int, int]] = None,
method: str = "bicubic",
) -> (
Float[Array, "padded_y_dim padded_x_dim"]
Inexact[Array, "padded_y_dim padded_x_dim"]
| Complex[Array, "padded_y_dim padded_x_dim//2+1"]
):
"""Rescale the image pixel size using real-space interpolation. Only
Expand Down
14 changes: 7 additions & 7 deletions src/cryojax/simulator/_instrument_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import jax.numpy as jnp
from equinox import Module
from jaxtyping import Array, Float
from jaxtyping import Array, Float, Inexact

from .._errors import error_if_not_positive
from ..constants import convert_keV_to_angstroms
Expand Down Expand Up @@ -202,20 +202,20 @@ def padded_full_frequency_grid_in_angstroms(
)

def crop_to_shape(
self, image: Float[Array, "y_dim x_dim"]
) -> Float[Array, "{self.y_dim} {self.x_dim}"]:
self, image: Inexact[Array, "y_dim x_dim"]
) -> Inexact[Array, "{self.y_dim} {self.x_dim}"]:
"""Crop an image to `shape`."""
return crop_to_shape(image, self.shape)

def pad_to_padded_shape(
self, image: Float[Array, "y_dim x_dim"], **kwargs: Any
) -> Float[Array, "{self.padded_y_dim} {self.padded_x_dim}"]:
self, image: Inexact[Array, "y_dim x_dim"], **kwargs: Any
) -> Inexact[Array, "{self.padded_y_dim} {self.padded_x_dim}"]:
"""Pad an image to `padded_shape`."""
return pad_to_shape(image, self.padded_shape, mode=self.pad_mode, **kwargs)

def crop_or_pad_to_padded_shape(
self, image: Float[Array, "y_dim x_dim"], **kwargs: Any
) -> Float[Array, "{self.padded_y_dim} {self.padded_x_dim}"]:
self, image: Inexact[Array, "y_dim x_dim"], **kwargs: Any
) -> Inexact[Array, "{self.padded_y_dim} {self.padded_x_dim}"]:
"""Reshape an image to `padded_shape` using cropping or padding."""
return resize_with_crop_or_pad(
image, self.padded_shape, mode=self.pad_mode, **kwargs
Expand Down
4 changes: 4 additions & 0 deletions src/cryojax/simulator/_multislice_integrator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base_multislice_integrator import (
AbstractMultisliceIntegrator as AbstractMultisliceIntegrator,
)
from .fft_multislice_integrator import FFTMultisliceIntegrator as FFTMultisliceIntegrator
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from abc import abstractmethod
from typing import Generic, TypeVar

from equinox import Module
from jaxtyping import Array, Complex

from .._instrument_config import InstrumentConfig


PotentialT = TypeVar("PotentialT")


class AbstractMultisliceIntegrator(Module, Generic[PotentialT], strict=True):
"""Base class for a multi-slice integration scheme."""

@abstractmethod
def compute_wavefunction_at_exit_plane(
self,
potential: PotentialT,
instrument_config: InstrumentConfig,
) -> Complex[
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}"
]:
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from typing import Any
from typing_extensions import override

import jax
import jax.numpy as jnp
from equinox import error_if
from jaxtyping import Array, Complex

# from cryojax.coordinates import make_frequency_grid
from cryojax.image import fftn, ifftn

from .._instrument_config import InstrumentConfig
from .._potential_representation import (
AbstractAtomicPotential,
)

# , RealVoxelGridPotential
from .._scattering_theory import compute_phase_shifts_from_integrated_potential
from .base_multislice_integrator import AbstractMultisliceIntegrator


class FFTMultisliceIntegrator(
AbstractMultisliceIntegrator[AbstractAtomicPotential], # | RealVoxelGridPotential],
strict=True,
):
"""Multislice integrator that steps using successive FFT-based convolutions."""

slice_thickness_in_voxels: int
# interpolation_order: int
options_for_rasterization: dict[str, Any]

def __init__(
self,
slice_thickness_in_voxels: int = 1,
*,
# interpolation_order: int = 1,
options_for_rasterization: dict[str, Any] = {},
):
"""**Arguments:**
- `slice_thickness_in_voxels`:
The number of slices to step through per iteration of the
rasterized voxel grid.
- `interpolation_order`:
The interpolation order. This can be `0` (nearest-neighbor), `1`
(linear), or `3` (cubic). See `cryojax.image.map_coordinates` for
documentation. Ignored if an `AbstractAtomicPotential` is passed.
- `options_for_rasterization`:
See `cryojax.simulator.AbstractAtomicPotential.as_real_voxel_grid`
for documentation. Ignored if a `RealVoxelGridPotential` is passed.
"""
if slice_thickness_in_voxels < 1:
raise AttributeError(
"FFTMultisliceIntegrator.slice_thickness_in_voxels must be an "
"integer greater than or equal to 1."
)
self.slice_thickness_in_voxels = slice_thickness_in_voxels
# self.interpolation_order = interpolation_order
self.options_for_rasterization = options_for_rasterization

@override
def compute_wavefunction_at_exit_plane(
self,
potential: AbstractAtomicPotential, # | RealVoxelGridPotential,
instrument_config: InstrumentConfig,
) -> Complex[
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}"
]:
"""Compute the exit wave from an atomic potential using the multislice
method.
**Arguments:**
- `potential`: The atomic potential to project.
- `instrument_config`: The configuration of the imaging instrument.
**Returns:**
The wavefunction in the exit plane of the specimen.
""" # noqa: E501
# Rasterize a voxel grid at the given settings
z_dim, y_dim, x_dim = (
min(instrument_config.padded_shape),
*instrument_config.padded_shape,
)
voxel_size = instrument_config.pixel_size
potential_voxel_grid = potential.as_real_voxel_grid(
(z_dim, y_dim, x_dim), voxel_size, **self.options_for_rasterization
)
# if isinstance(potential, AbstractAtomicPotential):
# z_dim, y_dim, x_dim = (
# min(instrument_config.padded_shape),
# *instrument_config.padded_shape,
# )
# voxel_size = instrument_config.pixel_size
# potential_voxel_grid = potential.as_real_voxel_grid(
# (z_dim, y_dim, x_dim), voxel_size, **self.options_for_rasterization
# )
# else:
# # Interpolate volume to new pose at given coordinate system
# z_dim, y_dim, x_dim = potential.real_voxel_grid.shape
# voxel_size = potential.voxel_size
# potential_voxel_grid = _interpolate_voxel_grid_to_rotated_coordinates(
# potential.real_voxel_grid,
# potential.coordinate_grid_in_pixels,
# self.interpolation_order,
# )
# Initialize multislice geometry
n_slices = z_dim // self.slice_thickness_in_voxels
slice_thickness = voxel_size * self.slice_thickness_in_voxels
# Locally average the potential to be at the given slice thickness.
# Thow away some slices equal to the remainder
# `dim % self.slice_thickness_in_voxels`
if self.slice_thickness_in_voxels > 1:
potential_per_slice = jnp.sum(
potential_voxel_grid[
: z_dim - z_dim % self.slice_thickness_in_voxels, ...
].reshape((self.slice_thickness_in_voxels, n_slices, y_dim, x_dim)),
axis=0,
)
# ... take care of remainder
if z_dim % self.slice_thickness_in_voxels != 0:
potential_per_slice = jnp.concatenate(
(
potential_per_slice,
potential_voxel_grid[
z_dim - z_dim % self.slice_thickness_in_voxels :, ...
],
)
)
else:
potential_per_slice = potential_voxel_grid
# Compute the integrated potential in a given slice interval, multiplying by
# the slice thickness (TODO: interpolate for different slice thicknesses?)
integrated_potential_per_slice = potential_per_slice * voxel_size
phase_shifts_per_slice = jax.vmap(
compute_phase_shifts_from_integrated_potential, in_axes=[0, None]
)(integrated_potential_per_slice, instrument_config.wavelength_in_angstroms)
# Compute the transmission function
transmission = jnp.exp(1.0j * phase_shifts_per_slice)
# Compute the fresnel propagator (TODO: check numerical factors)
radial_frequency_grid = jnp.sum(
instrument_config.padded_full_frequency_grid_in_angstroms**2,
axis=-1,
)
# if isinstance(potential, AbstractAtomicPotential):
# radial_frequency_grid = jnp.sum(
# instrument_config.padded_full_frequency_grid_in_angstroms**2,
# axis=-1,
# )
# else:
# radial_frequency_grid = jnp.sum(
# make_frequency_grid((y_dim, x_dim), voxel_size, half_space=False) ** 2,
# axis=-1,
# )
fresnel_propagator = jnp.exp(
1.0j
* jnp.pi
* instrument_config.wavelength_in_angstroms
* radial_frequency_grid
* slice_thickness
)
# Prepare for iteration. First, initialize plane wave
plane_wave = jnp.ones((y_dim, x_dim), dtype=complex)
# ... stepping function
make_step = lambda n, last_exit_wave: ifftn(
fftn(transmission[n, :, :] * last_exit_wave) * fresnel_propagator
)
# Compute exit wave
exit_wave = jax.lax.fori_loop(0, n_slices, make_step, plane_wave)

# return (
# exit_wave
# if isinstance(potential, AbstractAtomicPotential)
# else self._postprocess_exit_wave_for_voxel_potential(
# exit_wave, potential, instrument_config
# )
# )

return exit_wave

def _postprocess_exit_wave_for_voxel_potential(
self,
exit_wave: Complex[Array, "_ _"],
potential,
instrument_config: InstrumentConfig,
) -> Complex[
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}"
]:
# Check exit wave is at correct pixel size
exit_wave = error_if(
exit_wave,
~jnp.isclose(potential.voxel_size, instrument_config.pixel_size),
f"Tried to use {type(self).__name__} with `{type(potential).__name__}."
"voxel_size != InstrumentConfig.pixel_size`. If this is true, then "
f"`{type(self).__name__}.pixel_rescaling_method` must not be set to "
f"`None`. Try setting `{type(self).__name__}.pixel_rescaling_method = "
"'bicubic'`.",
)
# Resize the image to match the InstrumentConfig.padded_shape
if instrument_config.padded_shape != exit_wave.shape:
exit_wave = instrument_config.crop_or_pad_to_padded_shape(
exit_wave, constant_values=1.0 + 0.0j
)

return exit_wave


# def _interpolate_voxel_grid_to_rotated_coordinates(
# real_voxel_grid,
# coordinate_grid_in_pixels,
# interpolation_order,
# ):
# # Convert to logical coordinates
# z_dim, y_dim, x_dim = real_voxel_grid.shape
# logical_coordinate_grid = (
# coordinate_grid_in_pixels
# + jnp.asarray((x_dim // 2, y_dim // 2, z_dim // 2))[None, None, None, :]
# )
# # Convert arguments to map_coordinates convention and compute
# x, y, z = jnp.transpose(logical_coordinate_grid, axes=[3, 0, 1, 2])
# return map_coordinates(
# real_voxel_grid, (z, y, x), order=interpolation_order, mode="fill", cval=0.0
# )
11 changes: 10 additions & 1 deletion src/cryojax/simulator/_scattering_theory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from .base_scattering_theory import AbstractScatteringTheory as AbstractScatteringTheory
from .base_scattering_theory import (
AbstractScatteringTheory as AbstractScatteringTheory,
AbstractWaveScatteringTheory as AbstractWaveScatteringTheory,
)
from .common_functions import (
compute_phase_shifts_from_integrated_potential as compute_phase_shifts_from_integrated_potential, # noqa: E501
)
from .high_energy_scattering_theory import (
HighEnergyScatteringTheory as HighEnergyScatteringTheory,
)
from .multislice_scattering_theory import (
MultisliceScatteringTheory as MultisliceScatteringTheory,
)
from .weak_phase_scattering_theory import (
AbstractWeakPhaseScatteringTheory as AbstractWeakPhaseScatteringTheory,
LinearSuperpositionScatteringTheory as LinearSuperpositionScatteringTheory,
Expand Down
Loading

0 comments on commit 8de2685

Please sign in to comment.