From 921fb25f526e4f589a6ca586268fe843bf3d9a14 Mon Sep 17 00:00:00 2001 From: mjo22 Date: Mon, 31 Jul 2023 01:34:08 -0400 Subject: [PATCH] bug fix in masking experimental data.. tensorflow-nufft is now default. --- src/jax_2dtm/simulator/image.py | 4 +-- src/jax_2dtm/simulator/optics.py | 4 +-- src/jax_2dtm/utils/integration.py | 48 +++++++++++++++---------------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/jax_2dtm/simulator/image.py b/src/jax_2dtm/simulator/image.py index bccf715c..54d2af33 100644 --- a/src/jax_2dtm/simulator/image.py +++ b/src/jax_2dtm/simulator/image.py @@ -64,7 +64,7 @@ class Image(metaclass=ABCMeta): filters: InitVar[list[Filter] | None] = None masks: InitVar[list[Filter] | None] = None observed: InitVar[Array | None] = None - _process_observed: bool = True + _process_observed: bool = field(pytree_node=False, default=True) def __post_init__(self, filters, masks, observed): # Set filters @@ -227,7 +227,7 @@ def sample(self, state: Optional[ParameterState] = None) -> Array: state = state or self.state simulated = self.render(state) noise = state.noise.sample(self.config.freqs * self.config.pixel_size) - return simulated + self.mask(noise) + return simulated + fft(self.mask(ifft(noise))) def log_likelihood(self, state: Optional[ParameterState] = None) -> Scalar: """Evaluate the log-likelihood of the data given a parameter set.""" diff --git a/src/jax_2dtm/simulator/optics.py b/src/jax_2dtm/simulator/optics.py index 1c632465..718537be 100644 --- a/src/jax_2dtm/simulator/optics.py +++ b/src/jax_2dtm/simulator/optics.py @@ -78,8 +78,8 @@ class CTFOptics(Optics): b_factor : `jax_2dtm.types.Scalar` """ - defocus_u: Scalar = 8000.0 - defocus_v: Scalar = 8000.0 + defocus_u: Scalar = 10000.0 + defocus_v: Scalar = 10000.0 defocus_angle: Scalar = 0.0 voltage: Scalar = 300.0 spherical_aberration: Scalar = 2.7 diff --git a/src/jax_2dtm/utils/integration.py b/src/jax_2dtm/utils/integration.py index ce787a7e..d07ce0f6 100644 --- a/src/jax_2dtm/utils/integration.py +++ b/src/jax_2dtm/utils/integration.py @@ -6,10 +6,10 @@ import jax import jax.numpy as jnp -#import tensorflow_nufft as tfft -#from jax.experimental import jax2tf +import tensorflow_nufft as tfft +from jax.experimental import jax2tf -from jax_finufft import nufft1 +# from jax_finufft import nufft1 from jax.scipy import special from .fft import fftfreqs1d @@ -57,15 +57,15 @@ def nufft( """ complex_density = density.astype(complex) periodic_coords = 2 * jnp.pi * coords / box_size - #nufft1 = jax2tf.call_tf( - # _tf_nufft1, - # output_shape_dtype=jax.ShapeDtypeStruct(shape, complex_density.dtype), - #) - #ft = nufft1( - # complex_density, jnp.flip(periodic_coords, axis=-1), shape, eps - #) - x, y = periodic_coords.T - ft = nufft1(shape, complex_density, -y, -x, eps=eps) + nufft1 = jax2tf.call_tf( + _tf_nufft1, + output_shape_dtype=jax.ShapeDtypeStruct(shape, complex_density.dtype), + ) + ft = nufft1( + complex_density, jnp.flip(periodic_coords, axis=-1), shape, eps + ) + # x, y = periodic_coords.T + # ft = nufft1(shape, complex_density, -y, -x, eps=eps) return ft @@ -109,18 +109,18 @@ def integrate_gaussians( return image -#def _tf_nufft1(source, points, shape, tol): -# """ -# Wrapper for type-1 non-uniform FFT -# from tensorflow-nufft. -# """ -# return tfft.nufft( -# source, -# points, -# grid_shape=shape, -# transform_type="type_1", -# tol=tol.numpy(), -# ) +def _tf_nufft1(source, points, shape, tol): + """ + Wrapper for type-1 non-uniform FFT + from tensorflow-nufft. + """ + return tfft.nufft( + source, + points, + grid_shape=shape, + transform_type="type_1", + tol=tol.numpy(), + ) @jax.jit