Skip to content

Commit

Permalink
bug fix in masking experimental data.. tensorflow-nufft is now default.
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Jul 31, 2023
1 parent cd33585 commit 921fb25
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/jax_2dtm/simulator/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class Image(metaclass=ABCMeta):
filters: InitVar[list[Filter] | None] = None
masks: InitVar[list[Filter] | None] = None
observed: InitVar[Array | None] = None
_process_observed: bool = True
_process_observed: bool = field(pytree_node=False, default=True)

def __post_init__(self, filters, masks, observed):
# Set filters
Expand Down Expand Up @@ -227,7 +227,7 @@ def sample(self, state: Optional[ParameterState] = None) -> Array:
state = state or self.state
simulated = self.render(state)
noise = state.noise.sample(self.config.freqs * self.config.pixel_size)
return simulated + self.mask(noise)
return simulated + fft(self.mask(ifft(noise)))

def log_likelihood(self, state: Optional[ParameterState] = None) -> Scalar:
"""Evaluate the log-likelihood of the data given a parameter set."""
Expand Down
4 changes: 2 additions & 2 deletions src/jax_2dtm/simulator/optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ class CTFOptics(Optics):
b_factor : `jax_2dtm.types.Scalar`
"""

defocus_u: Scalar = 8000.0
defocus_v: Scalar = 8000.0
defocus_u: Scalar = 10000.0
defocus_v: Scalar = 10000.0
defocus_angle: Scalar = 0.0
voltage: Scalar = 300.0
spherical_aberration: Scalar = 2.7
Expand Down
48 changes: 24 additions & 24 deletions src/jax_2dtm/utils/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import jax
import jax.numpy as jnp
#import tensorflow_nufft as tfft
#from jax.experimental import jax2tf
import tensorflow_nufft as tfft
from jax.experimental import jax2tf

from jax_finufft import nufft1
# from jax_finufft import nufft1
from jax.scipy import special

from .fft import fftfreqs1d
Expand Down Expand Up @@ -57,15 +57,15 @@ def nufft(
"""
complex_density = density.astype(complex)
periodic_coords = 2 * jnp.pi * coords / box_size
#nufft1 = jax2tf.call_tf(
# _tf_nufft1,
# output_shape_dtype=jax.ShapeDtypeStruct(shape, complex_density.dtype),
#)
#ft = nufft1(
# complex_density, jnp.flip(periodic_coords, axis=-1), shape, eps
#)
x, y = periodic_coords.T
ft = nufft1(shape, complex_density, -y, -x, eps=eps)
nufft1 = jax2tf.call_tf(
_tf_nufft1,
output_shape_dtype=jax.ShapeDtypeStruct(shape, complex_density.dtype),
)
ft = nufft1(
complex_density, jnp.flip(periodic_coords, axis=-1), shape, eps
)
# x, y = periodic_coords.T
# ft = nufft1(shape, complex_density, -y, -x, eps=eps)

return ft

Expand Down Expand Up @@ -109,18 +109,18 @@ def integrate_gaussians(
return image


#def _tf_nufft1(source, points, shape, tol):
# """
# Wrapper for type-1 non-uniform FFT
# from tensorflow-nufft.
# """
# return tfft.nufft(
# source,
# points,
# grid_shape=shape,
# transform_type="type_1",
# tol=tol.numpy(),
# )
def _tf_nufft1(source, points, shape, tol):
"""
Wrapper for type-1 non-uniform FFT
from tensorflow-nufft.
"""
return tfft.nufft(
source,
points,
grid_shape=shape,
transform_type="type_1",
tol=tol.numpy(),
)


@jax.jit
Expand Down

0 comments on commit 921fb25

Please sign in to comment.