Skip to content

Commit

Permalink
Merge pull request #192 from mjo22/dynamic-type-hints
Browse files Browse the repository at this point in the history
Add dynamic type hints for image shapes at run-time
  • Loading branch information
mjo22 authored Mar 25, 2024
2 parents 0c50bb9 + 410f833 commit 53258cf
Show file tree
Hide file tree
Showing 17 changed files with 240 additions and 161 deletions.
10 changes: 5 additions & 5 deletions src/cryojax/inference/distributions/_gaussian_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def __init__(
def render(
self, *, get_real: bool = True
) -> (
Float[Array, "{self.pipeline.config.shape[0]} {self.pipeline.config.shape[1]}"]
| Complex[Array, "{self.pipeline.config.shape[0]} {self.config.shape[1]//2+1}"]
Float[Array, "{self.pipeline.config.y_dim} {self.pipeline.config.x_dim}"]
| Complex[Array, "{self.pipeline.config.y_dim} {self.config.x_dim//2+1}"]
):
"""Render the image formation model."""
return self.contrast_scale * self.pipeline.render(
Expand All @@ -64,8 +64,8 @@ def render(
def sample(
self, key: PRNGKeyArray, *, get_real: bool = True
) -> (
Float[Array, "{self.pipeline.config.shape[0]} {self.pipeline.config.shape[1]}"]
| Complex[Array, "{self.pipeline.config.shape[0]} {self.config.shape[1]//2+1}"]
Float[Array, "{self.pipeline.config.y_dim} {self.pipeline.config.x_dim}"]
| Complex[Array, "{self.pipeline.config.y_dim} {self.config.x_dim//2+1}"]
):
"""Sample from the gaussian noise model."""
N_pix = np.prod(self.pipeline.config.padded_shape)
Expand All @@ -85,7 +85,7 @@ def log_likelihood(
self,
observed: Complex[
Array,
"{self.pipeline.config.shape[0]} {self.pipeline.config.shape[1]//2+1}",
"{self.pipeline.config.y_dim} {self.pipeline.config.x_dim//2+1}",
],
) -> RealNumber:
"""Evaluate the log-likelihood of the gaussian noise model.
Expand Down
8 changes: 4 additions & 4 deletions src/cryojax/simulator/_assembly/_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import equinox as eqx
import jax
from equinox import AbstractVar
from jaxtyping import Array, Float, Shaped
from jaxtyping import Array, Float

from ...rotations import SO3
from .._conformation import AbstractConformation
Expand Down Expand Up @@ -59,18 +59,18 @@ def __check_init__(self):

@cached_property
@abstractmethod
def offsets_in_angstroms(self) -> Float[Array, "n_subunits 3"]:
def offsets_in_angstroms(self) -> Float[Array, "{n_subunits} 3"]:
"""The positions of each subunit."""
raise NotImplementedError

@cached_property
@abstractmethod
def rotations(self) -> Shaped[SO3, " n_subunits"]:
def rotations(self) -> SO3:
"""The relative rotations between subunits."""
raise NotImplementedError

@cached_property
def poses(self) -> Shaped[AbstractPose, " n_subunits"]:
def poses(self) -> AbstractPose:
"""
Draw the poses of the subunits in the lab frame, measured
from the rotation relative to the first subunit.
Expand Down
8 changes: 4 additions & 4 deletions src/cryojax/simulator/_assembly/_helix.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __check_init__(self):
)

@cached_property
def offsets_in_angstroms(self) -> Float[Array, "n_subunits 3"]:
def offsets_in_angstroms(self) -> Float[Array, "{self.n_subunits} 3"]:
"""Get the helical lattice positions in the center of mass frame."""
return compute_helical_lattice_positions(
self.rise,
Expand All @@ -98,7 +98,7 @@ def offsets_in_angstroms(self) -> Float[Array, "n_subunits 3"]:
)

@cached_property
def rotations(self) -> Shaped[SO3, "n_subunits"]:
def rotations(self) -> SO3:
"""Get the helical lattice rotations in the center of mass frame.
These are rotations of the initial subunit.
Expand All @@ -122,7 +122,7 @@ def compute_helical_lattice_positions(
n_subunits_per_start: int,
initial_displacement: Float[Array, "3"],
n_start: int = 1,
) -> Float[Array, "n_start*n_subunits_per_start 3"]:
) -> Float[Array, "{n_start*n_subunits_per_start} 3"]:
"""
Compute the lattice points of a helix for a given
rise, twist, radius, and start number.
Expand Down Expand Up @@ -208,7 +208,7 @@ def compute_helical_lattice_rotations(
n_subunits_per_start: int,
initial_rotation: Float[Array, "3 3"] = jnp.eye(3),
n_start: int = 1,
) -> Float[Array, "n_start*n_subunits_per_start 3"]:
) -> Float[Array, "{n_start*n_subunits_per_start} 3 3"]:
"""
Compute the relative rotations of subunits on a
helical lattice, parameterized by the
Expand Down
37 changes: 31 additions & 6 deletions src/cryojax/simulator/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
The image configuration and utility manager.
"""

import math
from functools import cached_property
from typing import Any, Callable, Optional, Union

Expand Down Expand Up @@ -129,12 +130,12 @@ def wrapped_padded_frequency_grid_in_angstroms(self) -> FrequencyGrid:
def rescale_to_pixel_size(
self,
real_or_fourier_image: (
Float[Array, "{self.padded_shape[0]} {self.padded_shape[1]}"]
| Complex[Array, "{self.padded_shape[0]} {self.padded_shape[1]//2+1}"]
Float[Array, "{self.padded_y_dim} {self.padded_x_dim}"]
| Complex[Array, "{self.padded_y_dim} {self.padded_x_dim//2+1}"]
),
current_pixel_size: RealNumber,
is_real: bool = True,
) -> Complex[Array, "{self.padded_shape[0]} {self.padded_shape[1]//2+1}"]:
) -> Complex[Array, "{self.padded_y_dim} {self.padded_x_dim//2+1}"]:
"""Rescale the image pixel size using real-space interpolation. Only
interpolate if the `pixel_size` is not the `current_pixel_size`."""
if is_real:
Expand All @@ -160,20 +161,44 @@ def rescale_to_pixel_size(

def crop_to_shape(
self, image: RealImage
) -> Float[Array, "{self.shape[0]} {self.shape[1]}"]:
) -> Float[Array, "{self.y_dim} {self.x_dim}"]:
"""Crop an image."""
return crop_to_shape(image, self.shape)

def pad_to_padded_shape(
self, image: RealImage, **kwargs: Any
) -> Float[Array, "{self.padded_shape[0]} {self.padded_shape[1]}"]:
) -> Float[Array, "{self.padded_y_dim} {self.padded_x_dim}"]:
"""Pad an image."""
return pad_to_shape(image, self.padded_shape, mode=self.pad_mode, **kwargs)

def crop_or_pad_to_padded_shape(
self, image: RealImage, **kwargs: Any
) -> Float[Array, "{self.padded_shape[0]} {self.padded_shape[1]}"]:
) -> Float[Array, "{self.padded_y_dim} {self.padded_x_dim}"]:
"""Reshape an image using cropping or padding."""
return resize_with_crop_or_pad(
image, self.padded_shape, mode=self.pad_mode, **kwargs
)

@property
def n_pix(self) -> int:
return math.prod(self.shape)

@property
def y_dim(self) -> int:
return self.shape[0]

@property
def x_dim(self) -> int:
return self.shape[1]

@property
def padded_y_dim(self) -> int:
return self.padded_shape[0]

@property
def padded_x_dim(self) -> int:
return self.padded_shape[1]

@property
def padded_n_pix(self) -> int:
return math.prod(self.padded_shape)
8 changes: 5 additions & 3 deletions src/cryojax/simulator/_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax.random as jr
import numpy as np
from equinox import AbstractVar, field, Module
from jaxtyping import PRNGKeyArray, Shaped
from jaxtyping import Array, Complex, PRNGKeyArray, Shaped

from ..core import error_if_not_fractional
from ..image import irfftn, rfftn
Expand Down Expand Up @@ -106,11 +106,13 @@ def sample(

def __call__(
self,
fourier_squared_wavefunction_at_detector_plane: ComplexImage,
fourier_squared_wavefunction_at_detector_plane: Complex[
Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"
],
dose: ElectronDose,
config: ImageConfig,
key: Optional[PRNGKeyArray] = None,
) -> ComplexImage:
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Pass the image through the detector model."""
N_pix = np.prod(config.padded_shape)
frequency_grid = config.wrapped_padded_frequency_grid.get()
Expand Down
33 changes: 22 additions & 11 deletions src/cryojax/simulator/_ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,31 @@
import jax.random as jr
import numpy as np
from equinox import Module
from jaxtyping import PRNGKeyArray
from jaxtyping import Array, Complex, PRNGKeyArray

from ..image.operators import FourierOperatorLike
from ..typing import ComplexImage, Image
from ._config import ImageConfig


class AbstractIce(Module, strict=True):
"""Base class for an ice model."""

@abstractmethod
def sample(self, key: PRNGKeyArray, config: ImageConfig) -> Image:
def sample(
self, key: PRNGKeyArray, config: ImageConfig
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Sample a stochastic realization of the potential due to the ice
at the exit plane."""
raise NotImplementedError

def __call__(
self,
key: PRNGKeyArray,
fourier_potential_at_exit_plane: ComplexImage,
fourier_potential_at_exit_plane: Complex[
Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"
],
config: ImageConfig,
) -> ComplexImage:
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Compute the combined potential of the ice and the specimen."""
# Sample the realization of the potential due to the ice.
fourier_ice_potential_at_exit_plane = self.sample(key, config)
Expand All @@ -42,20 +45,26 @@ class NullIce(AbstractIce):
"""A "null" ice model."""

@override
def sample(self, key: PRNGKeyArray, config: ImageConfig) -> Image:
def sample(
self, key: PRNGKeyArray, config: ImageConfig
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
return jnp.zeros(
config.wrapped_padded_frequency_grid_in_angstroms.get().shape[0:-1]
config.wrapped_padded_frequency_grid_in_angstroms.get().shape[0:-1],
dtype=complex,
)

@override
def __call__(
self,
key: PRNGKeyArray,
potential_at_exit_plane: ComplexImage,
fourier_potential_at_exit_plane: Complex[
Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"
],
config: ImageConfig,
) -> ComplexImage:
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
return jnp.zeros(
config.wrapped_padded_frequency_grid_in_angstroms.get().shape[0:-1]
config.wrapped_padded_frequency_grid_in_angstroms.get().shape[0:-1],
dtype=complex,
)


Expand All @@ -76,7 +85,9 @@ def __init__(self, variance: FourierOperatorLike):
self.variance = variance

@override
def sample(self, key: PRNGKeyArray, config: ImageConfig) -> ComplexImage:
def sample(
self, key: PRNGKeyArray, config: ImageConfig
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Sample a realization of the ice potential as colored gaussian noise."""
N_pix = np.prod(config.padded_shape)
frequency_grid_in_angstroms = (
Expand Down
32 changes: 22 additions & 10 deletions src/cryojax/simulator/_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

import jax.numpy as jnp
from equinox import Module
from jaxtyping import PRNGKeyArray
from jaxtyping import Array, Complex, PRNGKeyArray

from ..image import ifftn, rfftn
from ..typing import ComplexImage, Image, RealNumber
from ..typing import RealNumber
from ._config import ImageConfig
from ._detector import AbstractDetector, NullDetector
from ._dose import ElectronDose
Expand Down Expand Up @@ -63,10 +63,15 @@ def __init__(

def propagate_to_detector_plane(
self,
fourier_potential_at_exit_plane: ComplexImage,
fourier_potential_at_exit_plane: Complex[
Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"
],
config: ImageConfig,
defocus_offset: RealNumber | float = 0.0,
) -> Image:
) -> (
Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]
| Complex[Array, "{config.padded_y_dim} {config.padded_x_dim}"]
):
"""Propagate the scattering potential with the optics model."""
fourier_contrast_or_wavefunction_at_detector_plane = self.optics(
fourier_potential_at_exit_plane, config, defocus_offset=defocus_offset
Expand All @@ -76,9 +81,12 @@ def propagate_to_detector_plane(

def compute_fourier_squared_wavefunction(
self,
fourier_contrast_or_wavefunction_at_detector_plane: ComplexImage,
fourier_contrast_or_wavefunction_at_detector_plane: (
Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]
| Complex[Array, "{config.padded_y_dim} {config.padded_x_dim}"]
),
config: ImageConfig,
) -> ComplexImage:
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Compute the squared wavefunction at the detector plane, given either the
contrast or the wavefunction.
"""
Expand Down Expand Up @@ -114,9 +122,11 @@ def compute_fourier_squared_wavefunction(

def compute_expected_electron_events(
self,
fourier_squared_wavefunction_at_detector_plane: ComplexImage,
fourier_squared_wavefunction_at_detector_plane: Complex[
Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"
],
config: ImageConfig,
) -> ComplexImage:
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Compute the expected electron events from the detector."""
fourier_expected_electron_events = self.detector(
fourier_squared_wavefunction_at_detector_plane, self.dose, config, key=None
Expand All @@ -127,9 +137,11 @@ def compute_expected_electron_events(
def measure_detector_readout(
self,
key: PRNGKeyArray,
fourier_squared_wavefunction_at_detector_plane: ComplexImage,
fourier_squared_wavefunction_at_detector_plane: Complex[
Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"
],
config: ImageConfig,
) -> ComplexImage:
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Measure the readout from the detector."""
fourier_detector_readout = self.detector(
fourier_squared_wavefunction_at_detector_plane, self.dose, config, key
Expand Down
3 changes: 1 addition & 2 deletions src/cryojax/simulator/_integrators/_fourier_slice_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
map_coordinates_with_cubic_spline,
rfftn,
)
from ...typing import ComplexImage
from .._config import ImageConfig
from .._potential import (
FourierVoxelGridPotential,
Expand Down Expand Up @@ -52,7 +51,7 @@ def __call__(
self,
potential: FourierVoxelGridPotential | FourierVoxelGridPotentialInterpolator,
config: ImageConfig,
) -> ComplexImage:
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Compute a projection of the real-space potential by extracting
a central slice in fourier space.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/cryojax/simulator/_integrators/_nufft_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import jax.numpy as jnp
from equinox import field
from jaxtyping import Array, Complex

from ...typing import (
ComplexImage,
PointCloudCoords2D,
PointCloudCoords3D,
RealPointCloud,
Expand All @@ -35,7 +35,7 @@ def __call__(
self,
potential: RealVoxelGridPotential | RealVoxelCloudPotential,
config: ImageConfig,
) -> ComplexImage:
) -> Complex[Array, "{config.padded_y_dim} {config.padded_x_dim//2+1}"]:
"""Rasterize image with non-uniform FFTs."""
if isinstance(potential, RealVoxelGridPotential):
shape = potential.shape
Expand Down Expand Up @@ -67,7 +67,7 @@ def project_with_nufft(
coordinate_list: Union[PointCloudCoords2D, PointCloudCoords3D],
shape: tuple[int, int],
eps: float = 1e-6,
) -> ComplexImage:
) -> Complex[Array, "{shape[0]} {shape[1]}"]:
"""
Project and interpolate 3D volume point cloud
onto imaging plane using a non-uniform FFT.
Expand Down
Loading

0 comments on commit 53258cf

Please sign in to comment.