-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #258 from mjo22/multislice-updated
First pass at multislice implementation
- Loading branch information
Showing
15 changed files
with
995 additions
and
47 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,17 @@ | ||
from ..simulator._potential_integrator import ( | ||
EwaldSphereExtraction as EwaldSphereExtraction, | ||
from ..simulator._multislice_integrator import ( | ||
AbstractMultisliceIntegrator as AbstractMultisliceIntegrator, | ||
FFTMultisliceIntegrator as FFTMultisliceIntegrator, | ||
) | ||
|
||
# from ..simulator._potential_integrator import ( | ||
# EwaldSphereExtraction as EwaldSphereExtraction, | ||
# ) | ||
from ..simulator._scattering_theory import ( | ||
AbstractWaveScatteringTheory as AbstractWaveScatteringTheory, | ||
HighEnergyScatteringTheory as HighEnergyScatteringTheory, | ||
MultisliceScatteringTheory as MultisliceScatteringTheory, | ||
) | ||
from ..simulator._transfer_theory import ( | ||
WaveTransferFunction as WaveTransferFunction, | ||
WaveTransferTheory as WaveTransferTheory, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .base_multislice_integrator import ( | ||
AbstractMultisliceIntegrator as AbstractMultisliceIntegrator, | ||
) | ||
from .fft_multislice_integrator import FFTMultisliceIntegrator as FFTMultisliceIntegrator |
24 changes: 24 additions & 0 deletions
24
src/cryojax/simulator/_multislice_integrator/base_multislice_integrator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from abc import abstractmethod | ||
from typing import Generic, TypeVar | ||
|
||
from equinox import Module | ||
from jaxtyping import Array, Complex | ||
|
||
from .._instrument_config import InstrumentConfig | ||
|
||
|
||
PotentialT = TypeVar("PotentialT") | ||
|
||
|
||
class AbstractMultisliceIntegrator(Module, Generic[PotentialT], strict=True): | ||
"""Base class for a multi-slice integration scheme.""" | ||
|
||
@abstractmethod | ||
def compute_wavefunction_at_exit_plane( | ||
self, | ||
potential: PotentialT, | ||
instrument_config: InstrumentConfig, | ||
) -> Complex[ | ||
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}" | ||
]: | ||
raise NotImplementedError |
224 changes: 224 additions & 0 deletions
224
src/cryojax/simulator/_multislice_integrator/fft_multislice_integrator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
from typing import Any | ||
from typing_extensions import override | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from equinox import error_if | ||
from jaxtyping import Array, Complex | ||
|
||
# from cryojax.coordinates import make_frequency_grid | ||
from cryojax.image import fftn, ifftn | ||
|
||
from .._instrument_config import InstrumentConfig | ||
from .._potential_representation import ( | ||
AbstractAtomicPotential, | ||
) | ||
|
||
# , RealVoxelGridPotential | ||
from .._scattering_theory import compute_phase_shifts_from_integrated_potential | ||
from .base_multislice_integrator import AbstractMultisliceIntegrator | ||
|
||
|
||
class FFTMultisliceIntegrator( | ||
AbstractMultisliceIntegrator[AbstractAtomicPotential], # | RealVoxelGridPotential], | ||
strict=True, | ||
): | ||
"""Multislice integrator that steps using successive FFT-based convolutions.""" | ||
|
||
slice_thickness_in_voxels: int | ||
# interpolation_order: int | ||
options_for_rasterization: dict[str, Any] | ||
|
||
def __init__( | ||
self, | ||
slice_thickness_in_voxels: int = 1, | ||
*, | ||
# interpolation_order: int = 1, | ||
options_for_rasterization: dict[str, Any] = {}, | ||
): | ||
"""**Arguments:** | ||
- `slice_thickness_in_voxels`: | ||
The number of slices to step through per iteration of the | ||
rasterized voxel grid. | ||
- `interpolation_order`: | ||
The interpolation order. This can be `0` (nearest-neighbor), `1` | ||
(linear), or `3` (cubic). See `cryojax.image.map_coordinates` for | ||
documentation. Ignored if an `AbstractAtomicPotential` is passed. | ||
- `options_for_rasterization`: | ||
See `cryojax.simulator.AbstractAtomicPotential.as_real_voxel_grid` | ||
for documentation. Ignored if a `RealVoxelGridPotential` is passed. | ||
""" | ||
if slice_thickness_in_voxels < 1: | ||
raise AttributeError( | ||
"FFTMultisliceIntegrator.slice_thickness_in_voxels must be an " | ||
"integer greater than or equal to 1." | ||
) | ||
self.slice_thickness_in_voxels = slice_thickness_in_voxels | ||
# self.interpolation_order = interpolation_order | ||
self.options_for_rasterization = options_for_rasterization | ||
|
||
@override | ||
def compute_wavefunction_at_exit_plane( | ||
self, | ||
potential: AbstractAtomicPotential, # | RealVoxelGridPotential, | ||
instrument_config: InstrumentConfig, | ||
) -> Complex[ | ||
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}" | ||
]: | ||
"""Compute the exit wave from an atomic potential using the multislice | ||
method. | ||
**Arguments:** | ||
- `potential`: The atomic potential to project. | ||
- `instrument_config`: The configuration of the imaging instrument. | ||
**Returns:** | ||
The wavefunction in the exit plane of the specimen. | ||
""" # noqa: E501 | ||
# Rasterize a voxel grid at the given settings | ||
z_dim, y_dim, x_dim = ( | ||
min(instrument_config.padded_shape), | ||
*instrument_config.padded_shape, | ||
) | ||
voxel_size = instrument_config.pixel_size | ||
potential_voxel_grid = potential.as_real_voxel_grid( | ||
(z_dim, y_dim, x_dim), voxel_size, **self.options_for_rasterization | ||
) | ||
# if isinstance(potential, AbstractAtomicPotential): | ||
# z_dim, y_dim, x_dim = ( | ||
# min(instrument_config.padded_shape), | ||
# *instrument_config.padded_shape, | ||
# ) | ||
# voxel_size = instrument_config.pixel_size | ||
# potential_voxel_grid = potential.as_real_voxel_grid( | ||
# (z_dim, y_dim, x_dim), voxel_size, **self.options_for_rasterization | ||
# ) | ||
# else: | ||
# # Interpolate volume to new pose at given coordinate system | ||
# z_dim, y_dim, x_dim = potential.real_voxel_grid.shape | ||
# voxel_size = potential.voxel_size | ||
# potential_voxel_grid = _interpolate_voxel_grid_to_rotated_coordinates( | ||
# potential.real_voxel_grid, | ||
# potential.coordinate_grid_in_pixels, | ||
# self.interpolation_order, | ||
# ) | ||
# Initialize multislice geometry | ||
n_slices = z_dim // self.slice_thickness_in_voxels | ||
slice_thickness = voxel_size * self.slice_thickness_in_voxels | ||
# Locally average the potential to be at the given slice thickness. | ||
# Thow away some slices equal to the remainder | ||
# `dim % self.slice_thickness_in_voxels` | ||
if self.slice_thickness_in_voxels > 1: | ||
potential_per_slice = jnp.sum( | ||
potential_voxel_grid[ | ||
: z_dim - z_dim % self.slice_thickness_in_voxels, ... | ||
].reshape((self.slice_thickness_in_voxels, n_slices, y_dim, x_dim)), | ||
axis=0, | ||
) | ||
# ... take care of remainder | ||
if z_dim % self.slice_thickness_in_voxels != 0: | ||
potential_per_slice = jnp.concatenate( | ||
( | ||
potential_per_slice, | ||
potential_voxel_grid[ | ||
z_dim - z_dim % self.slice_thickness_in_voxels :, ... | ||
], | ||
) | ||
) | ||
else: | ||
potential_per_slice = potential_voxel_grid | ||
# Compute the integrated potential in a given slice interval, multiplying by | ||
# the slice thickness (TODO: interpolate for different slice thicknesses?) | ||
integrated_potential_per_slice = potential_per_slice * voxel_size | ||
phase_shifts_per_slice = jax.vmap( | ||
compute_phase_shifts_from_integrated_potential, in_axes=[0, None] | ||
)(integrated_potential_per_slice, instrument_config.wavelength_in_angstroms) | ||
# Compute the transmission function | ||
transmission = jnp.exp(1.0j * phase_shifts_per_slice) | ||
# Compute the fresnel propagator (TODO: check numerical factors) | ||
radial_frequency_grid = jnp.sum( | ||
instrument_config.padded_full_frequency_grid_in_angstroms**2, | ||
axis=-1, | ||
) | ||
# if isinstance(potential, AbstractAtomicPotential): | ||
# radial_frequency_grid = jnp.sum( | ||
# instrument_config.padded_full_frequency_grid_in_angstroms**2, | ||
# axis=-1, | ||
# ) | ||
# else: | ||
# radial_frequency_grid = jnp.sum( | ||
# make_frequency_grid((y_dim, x_dim), voxel_size, half_space=False) ** 2, | ||
# axis=-1, | ||
# ) | ||
fresnel_propagator = jnp.exp( | ||
1.0j | ||
* jnp.pi | ||
* instrument_config.wavelength_in_angstroms | ||
* radial_frequency_grid | ||
* slice_thickness | ||
) | ||
# Prepare for iteration. First, initialize plane wave | ||
plane_wave = jnp.ones((y_dim, x_dim), dtype=complex) | ||
# ... stepping function | ||
make_step = lambda n, last_exit_wave: ifftn( | ||
fftn(transmission[n, :, :] * last_exit_wave) * fresnel_propagator | ||
) | ||
# Compute exit wave | ||
exit_wave = jax.lax.fori_loop(0, n_slices, make_step, plane_wave) | ||
|
||
# return ( | ||
# exit_wave | ||
# if isinstance(potential, AbstractAtomicPotential) | ||
# else self._postprocess_exit_wave_for_voxel_potential( | ||
# exit_wave, potential, instrument_config | ||
# ) | ||
# ) | ||
|
||
return exit_wave | ||
|
||
def _postprocess_exit_wave_for_voxel_potential( | ||
self, | ||
exit_wave: Complex[Array, "_ _"], | ||
potential, | ||
instrument_config: InstrumentConfig, | ||
) -> Complex[ | ||
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}" | ||
]: | ||
# Check exit wave is at correct pixel size | ||
exit_wave = error_if( | ||
exit_wave, | ||
~jnp.isclose(potential.voxel_size, instrument_config.pixel_size), | ||
f"Tried to use {type(self).__name__} with `{type(potential).__name__}." | ||
"voxel_size != InstrumentConfig.pixel_size`. If this is true, then " | ||
f"`{type(self).__name__}.pixel_rescaling_method` must not be set to " | ||
f"`None`. Try setting `{type(self).__name__}.pixel_rescaling_method = " | ||
"'bicubic'`.", | ||
) | ||
# Resize the image to match the InstrumentConfig.padded_shape | ||
if instrument_config.padded_shape != exit_wave.shape: | ||
exit_wave = instrument_config.crop_or_pad_to_padded_shape( | ||
exit_wave, constant_values=1.0 + 0.0j | ||
) | ||
|
||
return exit_wave | ||
|
||
|
||
# def _interpolate_voxel_grid_to_rotated_coordinates( | ||
# real_voxel_grid, | ||
# coordinate_grid_in_pixels, | ||
# interpolation_order, | ||
# ): | ||
# # Convert to logical coordinates | ||
# z_dim, y_dim, x_dim = real_voxel_grid.shape | ||
# logical_coordinate_grid = ( | ||
# coordinate_grid_in_pixels | ||
# + jnp.asarray((x_dim // 2, y_dim // 2, z_dim // 2))[None, None, None, :] | ||
# ) | ||
# # Convert arguments to map_coordinates convention and compute | ||
# x, y, z = jnp.transpose(logical_coordinate_grid, axes=[3, 0, 1, 2]) | ||
# return map_coordinates( | ||
# real_voxel_grid, (z, y, x), order=interpolation_order, mode="fill", cval=0.0 | ||
# ) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.