Commit

Merge pull request #207 from mjo22/grid-search
Add a grid search API in cryojax.inference
mjo22 authored May 7, 2024
2 parents d9091c1 + 91233d9 commit 8062e86
Showing 110 changed files with 4,774 additions and 3,170 deletions.
77 changes: 37 additions & 40 deletions README.md
@@ -49,77 +49,74 @@ The [`jax-finufft`](https://github.com/dfm/jax-finufft) package is an optional d

The following is a basic workflow to simulate an image.

First, instantiate the scattering potential representation and its respective method for computing image projections.
First, instantiate the spatial potential energy distribution representation and its respective method for computing image projections.

```python
import jax
import jax.numpy as jnp
import cryojax.simulator as cs
from cryojax.io import read_array_with_spacing_from_mrc
import cryojax.simulator as cxs
from cryojax.data import read_array_with_spacing_from_mrc

# Instantiate the scattering potential.
# Instantiate the scattering potential
filename = "example_scattering_potential.mrc"
real_voxel_grid, voxel_size = read_array_with_spacing_from_mrc(filename)
potential = cs.FourierVoxelGridPotential.from_real_voxel_grid(real_voxel_grid, voxel_size)
# ... now instantiate fourier slice extraction
integrator = cs.FourierSliceExtract(interpolation_order=1)
```

Here, the 3D scattering potential array is read from `filename`. Then, the abstraction of the scattering potential is then loaded in fourier-space into a `FourierVoxelGridPotential`, and the fourier-slice projection theorem is initialized with `FourierSliceExtract`. The scattering potential can be generated with an external program, such as the [cisTEM](https://github.com/timothygrant80/cisTEM) simulate tool.

We can now instantiate the representation of a biological specimen, which also includes a pose.

```python
# First instantiate the pose. Here, angles are given in degrees
pose = cs.EulerAnglePose(
potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(real_voxel_grid, voxel_size)
# ... now, instantiate the pose. Angles are given in degrees
pose = cxs.EulerAnglePose(
offset_x_in_angstroms=5.0,
offset_y_in_angstroms=-3.0,
view_phi=20.0,
view_theta=80.0,
view_psi=-10.0,
)
# ... now, build the biological specimen
specimen = cs.Specimen(potential, integrator, pose)
# ... now, build the ensemble. In this case, the ensemble is just a single structure
structural_ensemble = cxs.SingleStructureEnsemble(potential, pose)
```

Next, build the model for the electron microscope. Here, we simply include a model for the CTF in the weak-phase approximation (linear image formation theory).
Here, the 3D scattering potential array is read from `filename`. Then, the abstraction of the scattering potential is then loaded in fourier-space into a `FourierVoxelGridPotential`. The scattering potential can be generated with an external program, such as the [cisTEM](https://github.com/timothygrant80/cisTEM) simulate tool. Then, the representation of a biological specimen is instantiated, which also includes a pose and conformational heterogeneity. Here, the `SingleStructureEnsemble` class takes a pose but has no heterogeneity.

Next, build the *scattering theory*. The simplest `scattering_theory` is the `LinearScatteringTheory`. This represents the usual image formation pipeline in cryo-EM, which forms images by computing projections of the potential and convolving the result with a contrast transfer function.

```python
from cryojax.image import operators as op

# First, initialize the CTF and its optics model
ctf = cs.CTF(
defocus_u_in_angstroms=10000.0,
defocus_v_in_angstroms=9800.0,
# Initialize the scattering theory. First, instantiate fourier slice extraction
potential_integrator = cxs.FourierSliceExtraction(interpolation_order=1)
# ... next, the contrast transfer theory
ctf = cxs.ContrastTransferFunction(
defocus_in_angstroms=9800.0,
astigmatism_in_angstroms=200.0,
astigmatism_angle=10.0,
amplitude_contrast_ratio=0.1)
optics = cs.WeakPhaseOptics(ctf, envelope=op.FourierGaussian(b_factor=5.0)) # b_factor is given in Angstroms^2
# ... these are stored in the Instrument
voltage_in_kilovolts = 300.0,
instrument = cs.Instrument(voltage_in_kilovolts, optics)
amplitude_contrast_ratio=0.1
)
transfer_theory = cxs.ContrastTransferTheory(ctf, envelope=op.FourierGaussian(b_factor=5.0))
# ... now for the scattering theory
scattering_theory = cxs.LinearScatteringTheory(structural_ensemble, potential_integrator, transfer_theory)
```
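
The `ContrastTransferFunction` parameters above can be read against the usual CTFFIND4-style expression for the CTF. The following is a minimal, library-independent sketch and not `cryojax`'s implementation: the function names are hypothetical, astigmatism is ignored (a single defocus value is used), the phase shift is assumed to be in radians, and the sign convention `-sin(chi)` is an assumption that may differ from the one `cryojax` uses.

```python
import math


def electron_wavelength_in_angstroms(voltage_in_kilovolts: float) -> float:
    """Relativistic electron wavelength in Angstroms for a given voltage."""
    voltage_in_volts = 1000.0 * voltage_in_kilovolts
    return 12.2643 / math.sqrt(voltage_in_volts + 0.97845e-6 * voltage_in_volts**2)


def ctf_value(
    spatial_frequency: float,  # in inverse Angstroms
    defocus_in_angstroms: float = 9800.0,
    voltage_in_kilovolts: float = 300.0,
    spherical_aberration_in_mm: float = 2.7,
    amplitude_contrast_ratio: float = 0.1,
    phase_shift: float = 0.0,  # assumed to be in radians
) -> float:
    """Evaluate an astigmatism-free CTF at one spatial frequency (hypothetical helper)."""
    wavelength = electron_wavelength_in_angstroms(voltage_in_kilovolts)
    cs = spherical_aberration_in_mm * 1e7  # convert mm to Angstroms
    k2 = spatial_frequency**2
    w = amplitude_contrast_ratio
    # Phase aberration: defocus term, spherical aberration term,
    # additional phase shift, and the amplitude contrast phase.
    chi = (
        math.pi * wavelength * k2 * (defocus_in_angstroms - 0.5 * wavelength**2 * k2 * cs)
        + phase_shift
        + math.atan(w / math.sqrt(1.0 - w**2))
    )
    return -math.sin(chi)


# Evaluate the CTF at a few spatial frequencies (inverse Angstroms).
values = [ctf_value(k) for k in (0.0, 0.05, 0.1, 0.2)]
```

In this sketch the value at zero spatial frequency reduces to minus the amplitude contrast ratio, which is a quick sanity check on the phase term.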

The `CTF` has parameters used in CTFFIND4, which take their default values if not
explicitly configured here. Finally, we can instantiate the `ImagePipeline` and simulate an image.
The `ContrastTransferFunction` has parameters used in CTFFIND4, which take their default values if not
explicitly configured here. Finally, we can instantiate the `imaging_pipeline`--the highest level of imaging abstraction in `cryojax`--and simulate an image. Here, we choose a `ContrastImagingPipeline`, which simulates image contrast from a linear scattering theory.

```python
# Instantiate the image configuration
config = cs.ImageConfig(shape=(320, 320), pixel_size=voxel_size)
# Build the image formation model
pipeline = cs.ImagePipeline(config, specimen, instrument)
# ... simulate an image and return in real-space.
image_without_noise = pipeline.render(get_real=True)
# Finally, build the image formation model
# ... first instantiate the instrument configuration
instrument_config = cxs.InstrumentConfig(shape=(320, 320), pixel_size=voxel_size, voltage_in_kilovolts=300.0)
# ... now the imaging pipeline
imaging_pipeline = cxs.ContrastImagingPipeline(instrument_config, scattering_theory)
# ... finally, simulate an image and return in real-space!
image_without_noise = imaging_pipeline.render(get_real=True)
```

`cryojax` also defines a library of distributions from which to sample the data. These distributions define the stochastic model from which images are drawn. For example, instantiate an `IndependentFourierGaussian` distribution and either sample from it or compute its log-likelihood.
`cryojax` also defines a library of distributions from which to sample the data. These distributions define the stochastic model from which images are drawn. For example, instantiate an `IndependentGaussianFourierModes` distribution and either sample from it or compute its log-likelihood.
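
As a library-independent illustration of what "sample from it or compute its log-likelihood" means for an independent Gaussian model with constant variance, here is a minimal sketch. It assumes real-valued data and hypothetical function names; it is not the `cryojax` implementation, which works on Fourier modes and may normalize differently.

```python
import math
import random


def sample(mean, variance, rng):
    """Draw one noisy realization: each entry is mean + independent Gaussian noise."""
    std = math.sqrt(variance)
    return [m + std * rng.gauss(0.0, 1.0) for m in mean]


def log_likelihood(observed, mean, variance):
    """Log-density of `observed` under independent Gaussians centered on `mean`."""
    n = len(observed)
    residual = sum((y - m) ** 2 for y, m in zip(observed, mean))
    return -0.5 * (residual / variance + n * math.log(2.0 * math.pi * variance))


rng = random.Random(0)
noiseless = [0.0, 1.0, 0.5, -0.25]  # stand-in for a simulated, noise-free image
noisy = sample(noiseless, 1.0, rng)
ll = log_likelihood(noisy, noiseless, 1.0)
```

The noise-free image plays the role of the mean, so the log-likelihood is largest when the observed data equal the simulated image.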

```python
from cryojax.image import rfftn
from cryojax.image import rfftn, operators as op
from cryojax.inference import distributions as dist
from cryojax.image import operators as op

# Passing the ImagePipeline and a variance function, instantiate the distribution
distribution = dist.IndependentFourierGaussian(pipeline, variance=op.Constant(1.0))
distribution = dist.IndependentGaussianFourierModes(
imaging_pipeline, variance_function=op.Constant(1.0)
)
# ... then, either simulate an image from this distribution
key = jax.random.PRNGKey(seed=0)
image_with_noise = distribution.sample(key)
6 changes: 3 additions & 3 deletions docs/api/simulator/scattering_potential.md
@@ -1,9 +1,9 @@
# Scattering potential representations

`cryojax` provides different options for how to represent scattering potentials in cryo-EM.
`cryojax` provides different options for how to represent spatial potential energy distributions in cryo-EM.

???+ abstract "`cryojax.simulator.AbstractScatteringPotential`"
::: cryojax.simulator.AbstractScatteringPotential
???+ abstract "`cryojax.simulator.AbstractPotentialRepresentation`"
::: cryojax.simulator.AbstractPotentialRepresentation
options:
members:
- rotate_to_pose
610 changes: 610 additions & 0 deletions docs/examples/cross-correlation-search.ipynb

Large diffs are not rendered by default.

Git LFS file not shown
43 changes: 19 additions & 24 deletions docs/examples/read-dataset.ipynb
@@ -63,7 +63,7 @@
" RELION STAR file for a particle stack. Upon accessing an image in the particle stack, a `RelionParticleStack`\n",
" is returned. Specifically, the `RelionParticleStack` stores the image(s) in the image stack, as well as the metadata\n",
" in the STAR file. The metadata is used to instantiate compatible `cryojax` objects. For example, the `RelionParticleStack`\n",
" stores a `cryojax` models for the contrast transfer function (the `CTF` class) and the pose (the `EulerAnglePose` class).\n",
" stores a `cryojax` models for the contrast transfer function (the `ContrastTransferFunction` class) and the pose (the `EulerAnglePose` class).\n",
"\n",
" More generally, a `RelionDataset` is an `AbstractDataset`, which is complemented by the abstraction of a particle stack: the `AbstractParticleStack`.\n",
" These abstract interfaces are part of the `cryojax` public API! "
@@ -80,16 +80,13 @@
"text": [
"RelionParticleStack(\n",
" image_stack=f32[100,100],\n",
" config=ImageConfig(\n",
" instrument_config=InstrumentConfig(\n",
" shape=(100, 100),\n",
" pixel_size=f32[],\n",
" voltage_in_kilovolts=f32[],\n",
" electrons_per_angstrom_squared=f32[],\n",
" padded_shape=(100, 100),\n",
" pad_mode='constant',\n",
" rescale_method='bicubic',\n",
" wrapped_frequency_grid=FrequencyGrid(array=f32[100,51,2]),\n",
" wrapped_padded_frequency_grid=FrequencyGrid(array=f32[100,51,2]),\n",
" wrapped_coordinate_grid=CoordinateGrid(array=f32[100,100,2]),\n",
" wrapped_padded_coordinate_grid=CoordinateGrid(array=f32[100,100,2])\n",
" pad_mode='constant'\n",
" ),\n",
" pose=EulerAnglePose(\n",
" offset_x_in_angstroms=f32[],\n",
@@ -99,11 +96,11 @@
" view_theta=f32[],\n",
" view_psi=f32[]\n",
" ),\n",
" ctf=CTF(\n",
" defocus_u_in_angstroms=f32[],\n",
" defocus_v_in_angstroms=f32[],\n",
" ctf=ContrastTransferFunction(\n",
" defocus_in_angstroms=f32[],\n",
" astigmatism_in_angstroms=f32[],\n",
" astigmatism_angle=f32[],\n",
" voltage_in_kilovolts=f32[],\n",
" voltage_in_kilovolts=300.0,\n",
" spherical_aberration_in_mm=f32[],\n",
" amplitude_contrast_ratio=f32[],\n",
" phase_shift=f32[]\n",
@@ -201,7 +198,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we see that the `image_stack` attribute has a leading dimension for each image. We can also inspect the metadata read from the STAR file by printing the `CTF`."
"Now, we see that the `image_stack` attribute has a leading dimension for each image. We can also inspect the metadata read from the STAR file by printing the `ContrastTransferFunction`."
]
},
{
@@ -213,11 +210,11 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CTF(\n",
" defocus_u_in_angstroms=Array([10050.97, 10050.97, 10050.97], dtype=float32),\n",
" defocus_v_in_angstroms=Array([9999.999, 9999.999, 9999.999], dtype=float32),\n",
"ContrastTransferFunction(\n",
" defocus_in_angstroms=Array([10050.97, 10050.97, 10050.97], dtype=float32),\n",
" astigmatism_in_angstroms=Array([-50.970703, -50.970703, -50.970703], dtype=float32),\n",
" astigmatism_angle=Array([-54.58706, -54.58706, -54.58706], dtype=float32),\n",
" voltage_in_kilovolts=Array(300., dtype=float32),\n",
" voltage_in_kilovolts=300.0,\n",
" spherical_aberration_in_mm=Array(2.7, dtype=float32),\n",
" amplitude_contrast_ratio=Array(0.1, dtype=float32),\n",
" phase_shift=Array([0., 0., 0.], dtype=float32)\n",
@@ -234,7 +231,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that not all attributes of the `CTF` have a leading dimension. For those familiar with RELION STAR format, only the `CTF` parameters stored on a per-particle basis have a leading dimension! Parameters stored in the opticsGroup do not have a leading dimension."
"Notice that not all attributes of the `ContrastTransferFunction` have a leading dimension. For those familiar with RELION STAR format, only the `ContrastTransferFunction` parameters stored on a per-particle basis have a leading dimension! Parameters stored in the opticsGroup do not have a leading dimension."
]
},
{
@@ -294,14 +291,12 @@
"# ... and the image in fourier space\n",
"fourier_image = rfftn(relion_particle.image_stack)\n",
"# ... and the cartesian coordinate system\n",
"pixel_size = relion_particle.config.pixel_size\n",
"pixel_size = relion_particle.instrument_config.pixel_size\n",
"frequency_grid_in_angstroms = (\n",
" relion_particle.config.wrapped_frequency_grid_in_angstroms.get()\n",
" relion_particle.instrument_config.wrapped_frequency_grid_in_angstroms.get()\n",
")\n",
"# ... now, compute a radial coordinate system\n",
"radial_frequency_grid_in_angstroms = jnp.linalg.norm(\n",
" frequency_grid_in_angstroms, axis=-1\n",
")\n",
"radial_frequency_grid_in_angstroms = jnp.linalg.norm(frequency_grid_in_angstroms, axis=-1)\n",
"# ... plot the image in fourier space and the radial frequency grid\n",
"fig, axes = plt.subplots(figsize=(5, 4), ncols=2)\n",
"plot_image(\n",
@@ -361,7 +356,7 @@
"\n",
"\n",
"fig, ax = plt.subplots(figsize=(4, 4))\n",
"N_pixels = math.prod(relion_particle.config.shape)\n",
"N_pixels = math.prod(relion_particle.instrument_config.shape)\n",
"spectrum, wavenumbers = powerspectrum(\n",
" fourier_image,\n",
" radial_frequency_grid_in_angstroms,\n",