Skip to content

Commit

Permalink
Merge pull request #200 from mjo22/voltage-refactor
Browse files Browse the repository at this point in the history
Move the accelerating voltage from the `CTF` to the `Instrument`
  • Loading branch information
mjo22 authored Apr 10, 2024
2 parents 35f906e + 6900b03 commit d9091c1
Show file tree
Hide file tree
Showing 53 changed files with 963 additions and 1,006 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ ctf = cs.CTF(
defocus_u_in_angstroms=10000.0,
defocus_v_in_angstroms=9800.0,
astigmatism_angle=10.0,
voltage_in_kilovolts=300.0,
amplitude_contrast_ratio=0.1)
optics = cs.WeakPhaseOptics(ctf, envelope=op.FourierGaussian(b_factor=5.0)) # b_factor is given in Angstroms^2
# ... these are stored in the Instrument
instrument = cs.Instrument(optics)
voltage_in_kilovolts = 300.0,
instrument = cs.Instrument(voltage_in_kilovolts, optics)
```

The `CTF` has parameters used in CTFFIND4, which take their default values if not
Expand Down
38 changes: 21 additions & 17 deletions docs/examples/simulate-image.ipynb

Large diffs are not rendered by default.

51 changes: 31 additions & 20 deletions docs/examples/simulate-micrograph.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ ctf = cs.CTF(
defocus_u_in_angstroms=10000.0,
defocus_v_in_angstroms=9800.0,
astigmatism_angle=10.0,
voltage_in_kilovolts=300.0,
amplitude_contrast_ratio=0.1)
optics = cs.WeakPhaseOptics(ctf, envelope=op.FourierGaussian(b_factor=5.0)) # b_factor is given in Angstroms^2
# ... these are stored in the Instrument
instrument = cs.Instrument(optics)
voltage_in_kilovolts = 300.0
instrument = cs.Instrument(voltage_in_kilovolts, optics)
```

The `CTF` has all parameters used in CTFFIND4, which take their default values if not
Expand Down
1 change: 0 additions & 1 deletion src/cryojax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,5 @@
io as io,
rotations as rotations,
simulator as simulator,
typing as typing,
)
from .cryojax_version import __version__ as __version__
1 change: 1 addition & 0 deletions src/cryojax/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
default_form_factor_params as default_form_factor_params,
get_form_factor_params as get_form_factor_params,
)
from ._unit_conversions import convert_keV_to_angstroms as convert_keV_to_angstroms
6 changes: 2 additions & 4 deletions src/cryojax/constants/_load_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from ..typing import IntegerPointCloud
from jaxtyping import Array, Float, Int


def _load_element_form_factor_params():
Expand All @@ -32,7 +30,7 @@ def _load_element_form_factor_params():

@partial(jax.jit, static_argnums=(1,))
def get_form_factor_params(
atom_names: IntegerPointCloud,
atom_names: Int[Array, " size"],
form_factor_params: Optional[Float[Array, "2 N k"]] = None,
):
"""
Expand Down
12 changes: 12 additions & 0 deletions src/cryojax/constants/_unit_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Unit conversions."""

import jax.numpy as jnp
from jaxtyping import Array, Float


def convert_keV_to_angstroms(
energy_in_keV: Float[Array, ""] | float,
) -> Float[Array, ""]:
"""Get the relativistic electron wavelength at a given accelerating voltage."""
energy_in_eV = 1000.0 * energy_in_keV # keV to eV
return jnp.asarray(12.2643 / (energy_in_eV + 0.97845e-6 * energy_in_eV**2) ** 0.5)
45 changes: 25 additions & 20 deletions src/cryojax/coordinates/_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,7 @@
import jax.numpy as jnp
import numpy as np
from equinox import AbstractVar
from jaxtyping import Array, Float

from ..typing import (
Image,
ImageCoords,
PointCloudCoords2D,
PointCloudCoords3D,
VolumeCoords,
VolumeSliceCoords,
)
from jaxtyping import Array, Float, Inexact


class AbstractCoordinates(eqx.Module, strict=True):
Expand Down Expand Up @@ -66,12 +57,16 @@ class CoordinateList(AbstractCoordinates, strict=True):
A Pytree that wraps a coordinate list.
"""

array: PointCloudCoords3D | PointCloudCoords2D = eqx.field(converter=jnp.asarray)
array: Float[Array, "size 3"] | Float[Array, "size 2"] = eqx.field(
converter=jnp.asarray
)

def __init__(self, coordinate_list: PointCloudCoords2D | PointCloudCoords3D):
def __init__(
self, coordinate_list: Float[Array, "size 2"] | Float[Array, "size 3"]
):
self.array = coordinate_list

def get(self) -> PointCloudCoords3D | PointCloudCoords2D:
def get(self) -> Float[Array, "size 3"] | Float[Array, "size 2"]:
return self.array


Expand All @@ -80,7 +75,9 @@ class CoordinateGrid(AbstractCoordinates, strict=True):
A Pytree that wraps a coordinate grid.
"""

array: ImageCoords | VolumeCoords = eqx.field(converter=jnp.asarray)
array: Float[Array, "y_dim x_dim 2"] | Float[Array, "z_dim y_dim x_dim 3"] = (
eqx.field(converter=jnp.asarray)
)

def __init__(
self,
Expand All @@ -89,7 +86,9 @@ def __init__(
):
self.array = make_coordinates(shape, grid_spacing)

def get(self) -> ImageCoords | VolumeCoords:
def get(
self,
) -> Float[Array, "y_dim x_dim 2"] | Float[Array, "z_dim y_dim x_dim 3"]:
return self.array


Expand All @@ -98,7 +97,9 @@ class FrequencyGrid(AbstractCoordinates, strict=True):
A Pytree that wraps a frequency grid.
"""

array: ImageCoords | VolumeCoords = eqx.field(converter=jnp.asarray)
array: Float[Array, "y_dim x_dim 2"] | Float[Array, "z_dim y_dim x_dim 3"] = (
eqx.field(converter=jnp.asarray)
)

def __init__(
self,
Expand All @@ -108,7 +109,9 @@ def __init__(
):
self.array = make_frequencies(shape, grid_spacing, half_space=half_space)

def get(self) -> ImageCoords | VolumeCoords:
def get(
self,
) -> Float[Array, "y_dim x_dim 2"] | Float[Array, "z_dim y_dim x_dim 3"]:
return self.array


Expand All @@ -120,7 +123,7 @@ class FrequencySlice(AbstractCoordinates, strict=True):
component in the center.
"""

array: VolumeSliceCoords = eqx.field(converter=jnp.asarray)
array: Float[Array, "1 y_dim x_dim 3"] = eqx.field(converter=jnp.asarray)

def __init__(
self,
Expand All @@ -144,7 +147,7 @@ def __init__(
)
self.array = frequency_slice

def get(self) -> VolumeSliceCoords:
def get(self) -> Float[Array, "1 y_dim x_dim 3"]:
return self.array


Expand Down Expand Up @@ -208,7 +211,9 @@ def make_frequencies(
return frequency_grid


def cartesian_to_polar(freqs: ImageCoords, square: bool = False) -> tuple[Image, Image]:
def cartesian_to_polar(
freqs: Float[Array, "y_dim x_dim 2"], square: bool = False
) -> tuple[Inexact[Array, "y_dim x_dim"], Inexact[Array, "y_dim x_dim"]]:
"""
Convert from cartesian to polar coordinates.
Expand Down
1 change: 0 additions & 1 deletion src/cryojax/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from ._dataset import AbstractDataset as AbstractDataset
from ._particle_stack import (
AbstractParticleStack as AbstractParticleStack,
CryojaxParticleStack as CryojaxParticleStack,
)
from ._relion import (
default_relion_make_config as default_relion_make_config,
Expand Down
27 changes: 3 additions & 24 deletions src/cryojax/data/_particle_stack.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
"""Pytrees that represent particle stacks."""

import jax.numpy as jnp
from equinox import AbstractVar, field, Module
from jaxtyping import Shaped

from ..inference.distributions import AbstractDistribution
from ..typing import Image
from equinox import AbstractVar, Module
from jaxtyping import Array, Inexact


class AbstractParticleStack(Module, strict=True):
Expand All @@ -16,21 +12,4 @@ class AbstractParticleStack(Module, strict=True):
image formation model, typically represented with `cryojax` objects.
"""

image_stack: AbstractVar[Shaped[Image, "..."]]


class CryojaxParticleStack(AbstractParticleStack, strict=True):
"""The standard particle stack supported by `cryojax`."""

image_stack: Shaped[Image, "..."] = field(converter=jnp.asarray)
distribution: AbstractDistribution


CryojaxParticleStack.__init__.__doc__ = """**Arguments:**
- `image_stack`: The stack of images. The shape of this array
is a leading batch dimension followed by the shape
of an image in the stack.
- `distribution`: The statistical model from which the data is generated.
Any subset of pytree leaves may have a batch dimension.
"""
image_stack: AbstractVar[Inexact[Array, "... y_dim x_dim"]]
9 changes: 4 additions & 5 deletions src/cryojax/data/_relion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
import mrcfile
import numpy as np
import pandas as pd
from jaxtyping import Float, Int, Shaped
from jaxtyping import Array, Float, Int

from ..io import read_and_validate_starfile
from ..simulator import CTF, EulerAnglePose, ImageConfig
from ..typing import RealImage
from ._dataset import AbstractDataset
from ._particle_stack import AbstractParticleStack

Expand All @@ -39,14 +38,14 @@ class RelionParticleStack(AbstractParticleStack):
[RELION](https://relion.readthedocs.io/en/release-5.0/).
"""

image_stack: Shaped[RealImage, "..."]
image_stack: Float[Array, "... y_dim x_dim"]
config: ImageConfig
pose: EulerAnglePose
ctf: CTF

def __init__(
self,
image_stack: Shaped[RealImage, "..."] | Float[np.ndarray, "... Ny Nx"],
image_stack: Float[Array, "... y_dim x_dim"],
config: ImageConfig,
pose: EulerAnglePose,
ctf: CTF,
Expand Down Expand Up @@ -313,7 +312,7 @@ def __getitem__(
tuple([jnp.asarray(value) for value in pose_parameter_values]),
)

return RelionParticleStack(image_stack, config, pose, ctf)
return RelionParticleStack(jnp.asarray(image_stack), config, pose, ctf)

@final
def __len__(self) -> int:
Expand Down
57 changes: 26 additions & 31 deletions src/cryojax/image/_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,65 +9,60 @@
import jax.numpy as jnp
from jaxtyping import Array, Float, Inexact

from ..typing import (
Image,
RealImage,
RealVolume,
Volume,
)


Vector = Inexact[Array, "N"]
RealVector = Float[Array, "N"]


@overload
def radial_average(
image: Image,
radial_grid: RealImage,
bins: RealVector,
) -> Vector: ...
image: Inexact[Array, "y_dim x_dim"],
radial_grid: Float[Array, "y_dim x_dim"],
bins: Float[Array, " n_bins"],
) -> Inexact[Array, " n_bins"]: ...


@overload
def radial_average(
image: Volume,
radial_grid: RealVolume,
bins: RealVector,
) -> Vector: ...
image: Inexact[Array, "z_dim y_dim x_dim"],
radial_grid: Float[Array, "z_dim y_dim x_dim"],
bins: Float[Array, " n_bins"],
) -> Inexact[Array, " n_bins"]: ...


@overload
def radial_average(
image: Image,
radial_grid: RealImage,
bins: RealVector,
image: Inexact[Array, "y_dim x_dim"],
radial_grid: Float[Array, "y_dim x_dim"],
bins: Float[Array, " n_bins"],
*,
to_grid: bool = False,
interpolation_mode: str = "nearest",
) -> tuple[Vector, Image]: ...
) -> tuple[Inexact[Array, " n_bins"], Inexact[Array, "y_dim x_dim"]]: ...


@overload
def radial_average(
image: Volume,
radial_grid: RealVolume,
bins: RealVector,
image: Inexact[Array, "z_dim y_dim x_dim"],
radial_grid: Float[Array, "z_dim y_dim x_dim"],
bins: Float[Array, " n_bins"],
*,
to_grid: bool = False,
interpolation_mode: str = "nearest",
) -> tuple[Vector, Volume]: ...
) -> tuple[Inexact[Array, " n_bins"], Inexact[Array, "z_dim y_dim x_dim"]]: ...


@partial(jax.jit, static_argnames=["to_grid", "interpolation_mode"])
def radial_average(
image: Image | Volume,
radial_grid: RealImage | RealVolume,
bins: RealVector,
image: Inexact[Array, "y_dim x_dim"] | Inexact[Array, "z_dim y_dim x_dim"],
radial_grid: Float[Array, "y_dim x_dim"] | Float[Array, "z_dim y_dim x_dim"],
bins: Float[Array, " n_bins"],
*,
to_grid: bool = False,
interpolation_mode: str = "nearest",
) -> Vector | tuple[Vector, Image | Volume]:
) -> (
Inexact[Array, " n_bins"]
| tuple[
Inexact[Array, " n_bins"],
Inexact[Array, "y_dim x_dim"] | Inexact[Array, "z_dim y_dim x_dim"],
]
):
"""
Radially average vectors r with a given magnitude
coordinate system |r|.
Expand Down
Loading

0 comments on commit d9091c1

Please sign in to comment.