Skip to content

Commit

Permalink
get rid of . unclear what the difference between and are.
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed May 6, 2024
1 parent 50e2ef6 commit 91233d9
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 175 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ from cryojax.image import rfftn, operators as op
from cryojax.inference import distributions as dist

# Passing the ImagePipeline and a variance function, instantiate the distribution
distribution = dist.IndependentGaussianFourierModes(imaging_pipeline, variance_function=op.Constant(1.0))
distribution = dist.IndependentGaussianFourierModes(
imaging_pipeline, variance_function=op.Constant(1.0)
)
# ... then, either simulate an image from this distribution
key = jax.random.PRNGKey(seed=0)
image_with_noise = distribution.sample(key)
Expand Down
18 changes: 9 additions & 9 deletions docs/examples/simulate-image.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/cryojax/inference/distributions/_base_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def sample(
raise NotImplementedError

@abstractmethod
def render(self, *, get_real: bool = True) -> Inexact[Array, "y_dim x_dim"]:
def compute_signal(self, *, get_real: bool = True) -> Inexact[Array, "y_dim x_dim"]:
"""Render the image formation model."""
raise NotImplementedError

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self.signal_scale_factor = error_if_not_positive(jnp.asarray(signal_scale_factor))

@override
def render(
def compute_signal(
self, *, get_real: bool = True
) -> (
Float[
Expand Down Expand Up @@ -116,7 +116,7 @@ def sample(
.astype(complex),
get_real=get_real,
)
image = self.render(get_real=get_real)
image = self.compute_signal(get_real=get_real)
return image + noise

@override
Expand All @@ -140,7 +140,7 @@ def log_likelihood(
# Compute the variance and scale up to be independent of the number of pixels
variance = n_pixels * self.variance_function(freqs)
# Create simulated data
simulated = self.render(get_real=False)
simulated = self.compute_signal(get_real=False)
# Compute residuals
residuals = simulated - observed
# Compute standard normal random variables
Expand Down
214 changes: 57 additions & 157 deletions src/cryojax/simulator/_imaging_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
class AbstractImagingPipeline(Module, strict=True):
"""Base class for an image formation model.
Call an `AbstractImagingPipeline`'s `render` and `sample`,
routines.
Call an `AbstractImagingPipeline`'s `render` routine.
"""

instrument_config: AbstractVar[InstrumentConfig]
Expand All @@ -31,6 +30,7 @@ class AbstractImagingPipeline(Module, strict=True):
@abstractmethod
def render(
self,
rng_key: Optional[PRNGKeyArray] = None,
*,
postprocess: bool = True,
get_real: bool = True,
Expand All @@ -54,6 +54,8 @@ def render(
**Arguments:**
- `rng_key`: The random number generator key. If not passed, render an image
with no stochasticity.
- `postprocess`: If `True`, view the cropped, filtered, and masked image.
If `postprocess = False`, `ImagePipeline.filter`,
`ImagePipeline.mask`, and cropping to `InstrumentConfig.shape`
Expand All @@ -63,41 +65,6 @@ def render(
"""
raise NotImplementedError

@abstractmethod
def sample(
self,
rng_key: PRNGKeyArray,
*,
postprocess: bool = True,
get_real: bool = True,
) -> (
Float[Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim}"]
| Float[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim}",
]
| Complex[
Array,
"{self.instrument_config.y_dim} " "{self.instrument_config.x_dim//2+1}",
]
| Complex[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim//2+1}",
]
):
"""Sample an image from a realization of the stochastic models contained
in the `AbstractImagingPipeline`.
See `ImagePipeline.render` for documentation of keyword arguments.
**Arguments:**
- `rng_key`: The random number generator key.
"""
raise NotImplementedError

def postprocess(
self,
image: Complex[
Expand Down Expand Up @@ -222,37 +189,11 @@ def __init__(

@override
def render(
self, *, postprocess: bool = True, get_real: bool = True
) -> (
Float[Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim}"]
| Float[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim}",
]
| Complex[
Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim//2+1}"
]
| Complex[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim//2+1}",
]
):
# Compute the squared wavefunction
fourier_contrast_at_detector_plane = (
self.scattering_theory.compute_fourier_contrast_at_detector_plane(
self.instrument_config
)
)

return self._maybe_postprocess(
fourier_contrast_at_detector_plane, postprocess=postprocess, get_real=get_real
)

@override
def sample(
self, rng_key: PRNGKeyArray, *, postprocess: bool = True, get_real: bool = True
self,
rng_key: Optional[PRNGKeyArray] = None,
*,
postprocess: bool = True,
get_real: bool = True,
) -> (
Float[Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim}"]
| Float[
Expand Down Expand Up @@ -315,41 +256,11 @@ def __init__(

@override
def render(
self, *, postprocess: bool = True, get_real: bool = True
) -> (
Float[Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim}"]
| Float[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim}",
]
| Complex[
Array,
"{self.instrument_config.y_dim} {self.instrument_config.x_dim//2+1}",
]
| Complex[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim//2+1}",
]
):
# Compute the squared wavefunction
theory = self.scattering_theory
fourier_squared_wavefunction_at_detector_plane = (
theory.compute_fourier_squared_wavefunction_at_detector_plane(
self.instrument_config,
)
)

return self._maybe_postprocess(
fourier_squared_wavefunction_at_detector_plane,
postprocess=postprocess,
get_real=get_real,
)

@override
def sample(
self, rng_key: PRNGKeyArray, *, postprocess: bool = True, get_real: bool = True
self,
rng_key: Optional[PRNGKeyArray] = None,
*,
postprocess: bool = True,
get_real: bool = True,
) -> (
Float[Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim}"]
| Float[
Expand Down Expand Up @@ -419,7 +330,11 @@ def __init__(

@override
def render(
self, *, postprocess: bool = True, get_real: bool = True
self,
rng_key: Optional[PRNGKeyArray] = None,
*,
postprocess: bool = True,
get_real: bool = True,
) -> (
Float[Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim}"]
| Float[
Expand All @@ -428,67 +343,52 @@ def render(
"{self.instrument_config.padded_x_dim}",
]
| Complex[
Array,
"{self.instrument_config.y_dim} {self.instrument_config.x_dim//2+1}",
Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim//2+1}"
]
| Complex[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim//2+1}",
]
):
# Compute the squared wavefunction
theory = self.scattering_theory
fourier_squared_wavefunction_at_detector_plane = (
theory.compute_fourier_squared_wavefunction_at_detector_plane(
self.instrument_config
if rng_key is None:
# Compute the squared wavefunction
theory = self.scattering_theory
fourier_squared_wavefunction_at_detector_plane = (
theory.compute_fourier_squared_wavefunction_at_detector_plane(
self.instrument_config
)
)
# ... now measure the expected electron events at the detector
fourier_expected_electron_events = (
self.detector.compute_expected_electron_events(
fourier_squared_wavefunction_at_detector_plane, self.instrument_config
)
)
)
# ... now measure the expected electron events at the detector
fourier_expected_electron_events = self.detector.compute_expected_electron_events(
fourier_squared_wavefunction_at_detector_plane, self.instrument_config
)

return self._maybe_postprocess(
fourier_expected_electron_events, postprocess=postprocess, get_real=get_real
)

@override
def sample(
self, rng_key: PRNGKeyArray, *, postprocess: bool = True, get_real: bool = True
) -> (
Float[Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim}"]
| Float[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim}",
]
| Complex[
Array, "{self.instrument_config.y_dim} {self.instrument_config.x_dim//2+1}"
]
| Complex[
Array,
"{self.instrument_config.padded_y_dim} "
"{self.instrument_config.padded_x_dim//2+1}",
]
):
keys = jax.random.split(rng_key)
# Compute the squared wavefunction
theory = self.scattering_theory
fourier_squared_wavefunction_at_detector_plane = (
theory.compute_fourier_squared_wavefunction_at_detector_plane(
self.instrument_config, keys[0]
return self._maybe_postprocess(
fourier_expected_electron_events,
postprocess=postprocess,
get_real=get_real,
)
else:
keys = jax.random.split(rng_key)
# Compute the squared wavefunction
theory = self.scattering_theory
fourier_squared_wavefunction_at_detector_plane = (
theory.compute_fourier_squared_wavefunction_at_detector_plane(
self.instrument_config, keys[0]
)
)
# ... now measure the detector readout
fourier_detector_readout = self.detector.compute_detector_readout(
keys[1],
fourier_squared_wavefunction_at_detector_plane,
self.instrument_config,
)
)
# ... now measure the detector readout
fourier_detector_readout = self.detector.compute_detector_readout(
keys[1],
fourier_squared_wavefunction_at_detector_plane,
self.instrument_config,
)

return self._maybe_postprocess(
fourier_detector_readout,
postprocess=postprocess,
get_real=get_real,
)
return self._maybe_postprocess(
fourier_detector_readout,
postprocess=postprocess,
get_real=get_real,
)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,5 @@ def noisy_model(config, theory_with_solvent, detector):

@pytest.fixture
def test_image(noisy_model):
image = noisy_model.sample(jr.PRNGKey(1234))
image = noisy_model.render(jr.PRNGKey(1234))
return rfftn(image)
4 changes: 2 additions & 2 deletions tests/test_helix.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_superposition_pipeline_without_conformation(sample_subunit_mrc_path, co
instrument_config=config, scattering_theory=theory
)
_ = pipeline.render()
_ = pipeline.sample(jax.random.PRNGKey(0))
_ = pipeline.render(jax.random.PRNGKey(0))


def test_superposition_pipeline_with_conformation(sample_subunit_mrc_path, config):
Expand All @@ -85,7 +85,7 @@ def test_superposition_pipeline_with_conformation(sample_subunit_mrc_path, confi
instrument_config=config, scattering_theory=theory
)
_ = pipeline.render()
_ = pipeline.sample(jax.random.PRNGKey(0))
_ = pipeline.render(jax.random.PRNGKey(0))


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def test_fourier_vs_real_normalized_image(noisy_model):
key = jax.random.PRNGKey(1234)
im1 = normalize_image(noisy_model.sample(key, get_real=True), is_real=True)
im1 = normalize_image(noisy_model.render(key, get_real=True), is_real=True)
im2 = irfftn(
normalize_image(
noisy_model.render(get_real=False),
Expand Down

0 comments on commit 91233d9

Please sign in to comment.