Skip to content

Commit

Permalink
Merge pull request #214 from mjo22/api-update-naming
Browse files Browse the repository at this point in the history
Settle on naming choices in new API
  • Loading branch information
mjo22 authored May 5, 2024
2 parents 01e0afe + 4186351 commit bdc037b
Show file tree
Hide file tree
Showing 60 changed files with 1,521 additions and 1,323 deletions.
24 changes: 12 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,30 +81,30 @@ Next, build the *scattering theory*. The simplest `scattering_theory` is the `Li
from cryojax.image import operators as op

# Initialize the scattering theory. First, instantiate fourier slice extraction
projection_method = cxs.FourierSliceExtract(interpolation_order=1)
potential_integrator = cxs.FourierSliceExtraction(interpolation_order=1)
# ... next, the contrast transfer theory
transfer_function = cxs.AberratedCTF(
defocus_u_in_angstroms=10000.0,
defocus_v_in_angstroms=9800.0,
ctf = cxs.ContrastTransferFunction(
defocus_in_angstroms=9800.0,
astigmatism_in_angstroms=200.0,
astigmatism_angle=10.0,
amplitude_contrast_ratio=0.1
)
transfer_theory = cxs.ContrastTransferTheory(transfer_function, envelope=op.FourierGaussian(b_factor=5.0))
transfer_theory = cxs.ContrastTransferTheory(ctf, envelope=op.FourierGaussian(b_factor=5.0))
# ... now for the scattering theory
scattering_theory = cxs.LinearScatteringTheory(structural_ensemble, projection_method, transfer_theory)
scattering_theory = cxs.LinearScatteringTheory(structural_ensemble, potential_integrator, transfer_theory)
```

The `AberratedCTF` has parameters used in CTFFIND4, which take their default values if not
explicitly configured here. Finally, we can instantiate the `pipeline`--the highest level of imaging abstraction in `cryojax`--and simulate an image. Here, we choose a `ContrastImagingPipeline`, which simulates image contrast from a linear scattering theory.
The `ContrastTransferFunction` has parameters used in CTFFIND4, which take their default values if not
explicitly configured here. Finally, we can instantiate the `imaging_pipeline`--the highest level of imaging abstraction in `cryojax`--and simulate an image. Here, we choose a `ContrastImagingPipeline`, which simulates image contrast from a linear scattering theory.

```python
# Finally, build the image formation model
# ... first instantiate the instrument configuration
config = cxs.InstrumentConfig(shape=(320, 320), pixel_size=voxel_size, voltage_in_kilovolts=300.0)
instrument_config = cxs.InstrumentConfig(shape=(320, 320), pixel_size=voxel_size, voltage_in_kilovolts=300.0)
# ... now the imaging pipeline
pipeline = cxs.ContrastImagingPipeline(config, scattering_theory)
imaging_pipeline = cxs.ContrastImagingPipeline(instrument_config, scattering_theory)
# ... finally, simulate an image and return in real-space!
image_without_noise = pipeline.render(get_real=True)
image_without_noise = imaging_pipeline.render(get_real=True)
```

`cryojax` also defines a library of distributions from which to sample the data. These distributions define the stochastic model from which images are drawn. For example, instantiate an `IndependentGaussianFourierModes` distribution and either sample from it or compute its log-likelihood.
Expand All @@ -114,7 +114,7 @@ from cryojax.image import rfftn, operators as op
from cryojax.inference import distributions as dist

# Passing the ImagePipeline and a variance function, instantiate the distribution
distribution = dist.IndependentGaussianFourierModes(pipeline, variance=op.Constant(1.0))
distribution = dist.IndependentGaussianFourierModes(imaging_pipeline, variance_function=op.Constant(1.0))
# ... then, either simulate an image from this distribution
key = jax.random.PRNGKey(seed=0)
image_with_noise = distribution.sample(key)
Expand Down
46 changes: 22 additions & 24 deletions docs/examples/read-dataset.ipynb

Large diffs are not rendered by default.

108 changes: 64 additions & 44 deletions docs/examples/simulate-image.ipynb

Large diffs are not rendered by default.

60 changes: 31 additions & 29 deletions docs/examples/simulate-micrograph.ipynb

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,30 @@ Next, build the *scattering theory*. The simplest `scattering_theory` is the `Li
from cryojax.image import operators as op

# Initialize the scattering theory. First, instantiate fourier slice extraction
projection_method = cxs.FourierSliceExtract(interpolation_order=1)
potential_integrator = cxs.FourierSliceExtraction(interpolation_order=1)
# ... next, the contrast transfer theory
transfer_function = cxs.AberratedCTF(
defocus_u_in_angstroms=10000.0,
defocus_v_in_angstroms=9800.0,
ctf = cxs.ContrastTransferFunction(
defocus_in_angstroms=9800.0,
astigmatism_in_angstroms=200.0,
astigmatism_angle=10.0,
amplitude_contrast_ratio=0.1
)
transfer_theory = cxs.ContrastTransferTheory(transfer_function, envelope=op.FourierGaussian(b_factor=5.0))
transfer_theory = cxs.ContrastTransferTheory(ctf, envelope=op.FourierGaussian(b_factor=5.0))
# ... now for the scattering theory
scattering_theory = cxs.LinearScatteringTheory(structural_ensemble, projection_method, transfer_theory)
scattering_theory = cxs.LinearScatteringTheory(structural_ensemble, potential_integrator, transfer_theory)
```

The `AberratedCTF` has parameters used in CTFFIND4, which take their default values if not
explicitly configured here. Finally, we can instantiate the `pipeline`--the highest level of imaging abstraction in `cryojax`--and simulate an image. Here, we choose a `ContrastImagingPipeline`, which simulates image contrast from a linear scattering theory.
The `ContrastTransferFunction` has parameters used in CTFFIND4, which take their default values if not
explicitly configured here. Finally, we can instantiate the `imaging_pipeline`--the highest level of imaging abstraction in `cryojax`--and simulate an image. Here, we choose a `ContrastImagingPipeline`, which simulates image contrast from a linear scattering theory.

```python
# Finally, build the image formation model
# ... first instantiate the instrument configuration
config = cxs.InstrumentConfig(shape=(320, 320), pixel_size=voxel_size, voltage_in_kilovolts=300.0)
instrument_config = cxs.InstrumentConfig(shape=(320, 320), pixel_size=voxel_size, voltage_in_kilovolts=300.0)
# ... now the imaging pipeline
pipeline = cxs.ContrastImagingPipeline(config, scattering_theory)
imaging_pipeline = cxs.ContrastImagingPipeline(instrument_config, scattering_theory)
# ... finally, simulate an image and return in real-space!
image_contrast = pipeline.render(get_real=True)
image_without_noise = imaging_pipeline.render(get_real=True)
```

## Next steps
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ version-file = "src/cryojax/cryojax_version.py"
[tool.ruff]
extend-include = ["*.ipynb"]
lint.fixable = ["I001", "F401"]
line-length = 90
lint.ignore = ["E402", "E721", "E731", "E741", "F722"]
lint.ignore-init-module-imports = true
lint.select = ["E", "F", "I001"]
Expand All @@ -57,6 +58,9 @@ extra-standard-library = ["typing_extensions"]
lines-after-imports = 2
order-by-type = false

[tool.black]
line-length = 90

[tool.pyright]
reportIncompatibleMethodOverride = true
reportIncompatibleVariableOverride = false # Incompatible with eqx.AbstractVar
Expand Down
4 changes: 1 addition & 3 deletions src/cryojax/coordinates/_coordinate_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ def _make_coordinates_or_frequencies_1d(
) -> Float[Array, " size"]:
"""One-dimensional coordinates in real or fourier space"""
if real_space:
make_1d = (
lambda size, dx: jnp.fft.fftshift(jnp.fft.fftfreq(size, 1 / dx)) * size
)
make_1d = lambda size, dx: jnp.fft.fftshift(jnp.fft.fftfreq(size, 1 / dx)) * size
else:
if rfftfreq is None:
raise ValueError("Argument rfftfreq cannot be None if real_space=False.")
Expand Down
4 changes: 1 addition & 3 deletions src/cryojax/coordinates/_coordinate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ class CoordinateList(AbstractCoordinates, strict=True):
converter=jnp.asarray
)

def __init__(
self, coordinate_list: Float[Array, "size 2"] | Float[Array, "size 3"]
):
def __init__(self, coordinate_list: Float[Array, "size 2"] | Float[Array, "size 3"]):
self.array = coordinate_list

def get(self) -> Float[Array, "size 3"] | Float[Array, "size 2"]:
Expand Down
68 changes: 30 additions & 38 deletions src/cryojax/data/_relion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pandas as pd
from jaxtyping import Array, Float, Int

from ..simulator import AberratedCTF, EulerAnglePose, InstrumentConfig
from ..simulator import ContrastTransferFunction, EulerAnglePose, InstrumentConfig
from ._dataset import AbstractDataset
from ._io import read_and_validate_starfile
from ._particle_stack import AbstractParticleStack
Expand Down Expand Up @@ -39,28 +39,25 @@ class RelionParticleStack(AbstractParticleStack):
"""

image_stack: Float[Array, "... y_dim x_dim"]
config: InstrumentConfig
instrument_config: InstrumentConfig
pose: EulerAnglePose
transfer_function: AberratedCTF
ctf: ContrastTransferFunction

def __init__(
self,
image_stack: Float[Array, "... y_dim x_dim"],
config: InstrumentConfig,
instrument_config: InstrumentConfig,
pose: EulerAnglePose,
transfer_function: AberratedCTF,
ctf: ContrastTransferFunction,
):
# Set image stack and config as is
self.image_stack = jnp.asarray(image_stack)
self.config = config
self.instrument_config = instrument_config
# Set CTF using the defocus offset in the EulerAnglePose
self.transfer_function = eqx.tree_at(
lambda tf: (tf.defocus_u_in_angstroms, tf.defocus_v_in_angstroms),
transfer_function,
(
transfer_function.defocus_u_in_angstroms + pose.offset_z_in_angstroms,
transfer_function.defocus_v_in_angstroms + pose.offset_z_in_angstroms,
),
self.ctf = eqx.tree_at(
lambda tf: tf.defocus_in_angstroms,
ctf,
ctf.defocus_in_angstroms + pose.offset_z_in_angstroms,
)
# Set defocus offset to zero
self.pose = eqx.tree_at(
Expand All @@ -73,20 +70,19 @@ def __init__(
- `image_stack`: The stack of images. The shape of this array
is a leading batch dimension followed by the shape
of an image in the stack.
- `config`: The instrument configuration. Any subset of pytree leaves may
- `instrument_config`: The instrument configuration. Any subset of pytree leaves may
have a batch dimension.
- `pose`: The pose, represented by euler angles. Any subset of pytree leaves may
have a batch dimension. Upon instantiation, `pose.offset_z_in_angstroms`
is set to zero.
- `transfer_function`: The contrast transfer function. Any subset of pytree leaves may
- `ctf`: The contrast transfer function. Any subset of pytree leaves may
have a batch dimension. Upon instantiation,
`transfer_function.defocus_u_in_angstroms` is set to
`transfer_function.defocus_u_in_angstroms + pose.offset_z_in_angstroms` (and
also for `transfer_function.defocus_v_in_angstroms`).
`ctf.defocus_in_angstroms` is set to
`ctf.defocus_in_angstroms + pose.offset_z_in_angstroms`.
""" # noqa: E501


def _default_make_config_fn(
def _default_make_instrument_config_fn(
shape: tuple[int, int],
pixel_size: Float[Array, ""],
voltage_in_kilovolts: Float[Array, ""],
Expand All @@ -104,7 +100,7 @@ class RelionDataset(AbstractDataset):
path_to_relion_project: pathlib.Path
data_blocks: dict[str, pd.DataFrame]

make_config_fn: Callable[
make_instrument_config_fn: Callable[
[tuple[int, int], Float[Array, "..."], Float[Array, "..."]], InstrumentConfig
]

Expand All @@ -113,10 +109,10 @@ def __init__(
self,
path_to_starfile: str | pathlib.Path,
path_to_relion_project: str | pathlib.Path,
make_config_fn: Callable[
make_instrument_config_fn: Callable[
[tuple[int, int], Float[Array, "..."], Float[Array, "..."]],
InstrumentConfig,
] = _default_make_config_fn,
] = _default_make_instrument_config_fn,
):
"""**Arguments:**
Expand All @@ -129,7 +125,7 @@ def __init__(
object.__setattr__(
self, "path_to_relion_project", pathlib.Path(path_to_relion_project)
)
object.__setattr__(self, "make_config_fn", make_config_fn)
object.__setattr__(self, "make_instrument_config_fn", make_instrument_config_fn)

@final
def __getitem__(
Expand Down Expand Up @@ -216,9 +212,7 @@ def __getitem__(
np.asarray(image_stack_filename, dtype=object)[0],
)
# ... relion convention starts indexing at 1, not 0
particle_index = (
np.asarray(relion_particle_index.astype(int), dtype=int) - 1
)
particle_index = np.asarray(relion_particle_index.astype(int), dtype=int) - 1
else:
raise IOError(
"Could not read `rlnImageName` in STAR file for `RelionDataset` "
Expand All @@ -228,8 +222,10 @@ def __getitem__(
image_stack = np.asarray(mrc.data[particle_index]) # type: ignore
# Read metadata into a RelionParticleStack
# ... particle data
defocus_u_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"])
defocus_v_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusV"])
defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"])
astigmatism_in_angstroms = jnp.asarray(
particle_blocks["rlnDefocusV"]
) - jnp.asarray(particle_blocks["rlnDefocusU"])
astigmatism_angle = jnp.asarray(particle_blocks["rlnDefocusAngle"])
phase_shift = jnp.asarray(particle_blocks["rlnPhaseShift"])
# ... optics group data
Expand All @@ -239,14 +235,14 @@ def __getitem__(
spherical_aberration_in_mm = jnp.asarray(optics_group["rlnSphericalAberration"])
amplitude_contrast_ratio = jnp.asarray(optics_group["rlnAmplitudeContrast"])
# ... create cryojax objects
config = self.make_config_fn(
instrument_config = self.make_instrument_config_fn(
(int(image_size), int(image_size)),
pixel_size,
jnp.asarray(voltage_in_kilovolts),
)
transfer_function = AberratedCTF(
defocus_u_in_angstroms=defocus_u_in_angstroms,
defocus_v_in_angstroms=defocus_v_in_angstroms,
ctf = ContrastTransferFunction(
defocus_in_angstroms=defocus_in_angstroms,
astigmatism_in_angstroms=astigmatism_in_angstroms,
astigmatism_angle=astigmatism_angle,
voltage_in_kilovolts=voltage_in_kilovolts,
spherical_aberration_in_mm=spherical_aberration_in_mm,
Expand Down Expand Up @@ -302,9 +298,7 @@ def __getitem__(
if particle_blocks["rlnAnglePsi"] == -999.0
else particle_blocks["rlnAnglePsi"]
)
pose_parameter_names_and_values.append(
("view_psi", particle_blocks_for_psi)
)
pose_parameter_names_and_values.append(("view_psi", particle_blocks_for_psi))
elif "rlnAnglePsiPrior" in particle_keys: # support for helices
pose_parameter_names_and_values.append(
("view_psi", particle_blocks["rlnAnglePsiPrior"])
Expand All @@ -321,9 +315,7 @@ def __getitem__(
tuple([jnp.asarray(value) for value in pose_parameter_values]),
)

return RelionParticleStack(
jnp.asarray(image_stack), config, pose, transfer_function
)
return RelionParticleStack(jnp.asarray(image_stack), instrument_config, pose, ctf)

@final
def __len__(self) -> int:
Expand Down
4 changes: 1 addition & 3 deletions src/cryojax/image/_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ def radial_average(
average_as_profile,
).reshape(radial_grid.shape)
else:
raise ValueError(
f"interpolation_mode = {interpolation_mode} not supported."
)
raise ValueError(f"interpolation_mode = {interpolation_mode} not supported.")
return average_as_profile, average_as_grid
else:
return average_as_profile
8 changes: 2 additions & 6 deletions src/cryojax/image/_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ def crop_to_shape(


def crop_to_shape(
image_or_volume: (
Inexact[Array, "y_dim x_dim"] | Inexact[Array, "z_dim y_dim x_dim"]
),
image_or_volume: Inexact[Array, "y_dim x_dim"] | Inexact[Array, "z_dim y_dim x_dim"],
shape: tuple[int, int] | tuple[int, int, int],
) -> (
Inexact[Array, " {shape[0]} {shape[1]}"]
Expand Down Expand Up @@ -124,9 +122,7 @@ def pad_to_shape(


def pad_to_shape(
image_or_volume: (
Inexact[Array, "y_dim x_dim"] | Inexact[Array, "z_dim y_dim x_dim"]
),
image_or_volume: Inexact[Array, "y_dim x_dim"] | Inexact[Array, "z_dim y_dim x_dim"],
shape: tuple[int, int] | tuple[int, int, int],
**kwargs: Any,
) -> (
Expand Down
4 changes: 1 addition & 3 deletions src/cryojax/image/_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ def irfftn(
axes: Optional[tuple[int, ...]] = None,
**kwargs: Any,
) -> (
Float[Array, "y_dim x_dim"]
| Float[Array, "z_dim y_dim x_dim"]
| Float[Array, " *s"]
Float[Array, "y_dim x_dim"] | Float[Array, "z_dim y_dim x_dim"] | Float[Array, " *s"]
):
"""
Helper routine to compute the inverse fourier transform
Expand Down
2 changes: 1 addition & 1 deletion src/cryojax/image/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def compute_mean_and_std_from_fourier_image(
else:
N_modes = N1 * N2
# The mean is just the zero mode divided by the number of modes
mean = fourier_image[0, 0] / N_modes
mean = fourier_image[0, 0].real / N_modes
# The standard deviation is square root norm squared of the zero mean image
std = (
jnp.sqrt(
Expand Down
4 changes: 1 addition & 3 deletions src/cryojax/image/_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def powerspectrum(
k_max: Optional[Float[Array, ""] | float] = None,
) -> (
tuple[Float[Array, " n_bins"], Float[Array, " n_bins"]]
| tuple[
Float[Array, " n_bins"], Float[Array, "y_dim x_dim"], Float[Array, " n_bins"]
]
| tuple[Float[Array, " n_bins"], Float[Array, "y_dim x_dim"], Float[Array, " n_bins"]]
): ...


Expand Down
4 changes: 1 addition & 3 deletions src/cryojax/inference/_grid_search/search_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,7 @@ def batched_body_fun(iteration_index, carry):
)
tree_grid_points = tree_grid_take(
tree_grid,
tree_grid_unravel_index(
raveled_grid_index_batch, tree_grid, is_leaf=is_leaf
),
tree_grid_unravel_index(raveled_grid_index_batch, tree_grid, is_leaf=is_leaf),
)
new_state = method.batch_update(
fn, tree_grid_points, args, state, raveled_grid_index_batch
Expand Down
4 changes: 1 addition & 3 deletions src/cryojax/inference/_grid_search/search_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,7 @@ def postprocess(
if x.ndim > 1
else x.reshape(f_struct.shape)
)
value = jtu.tree_map(
_reshape_fn, tree_grid_take(tree_grid, tree_grid_index)
)
value = jtu.tree_map(_reshape_fn, tree_grid_take(tree_grid, tree_grid_index))
else:
value = None
# ... build and return the solution
Expand Down
4 changes: 1 addition & 3 deletions src/cryojax/inference/_transforms/lie_group_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ def __init__(self, quaternion_pose: QuaternionPose):
def get(self) -> Float[Array, "6"]:
"""An implementation of the `jaxlie.manifold.rplus`."""
local_tangent = self.transformed_parameter
group_element = jax.lax.stop_gradient(self.group_element) @ SE3.exp(
local_tangent
)
group_element = jax.lax.stop_gradient(self.group_element) @ SE3.exp(local_tangent)
return QuaternionPose.from_rotation_and_translation(
group_element.rotation, group_element.xyz
)
Loading

0 comments on commit bdc037b

Please sign in to comment.