Skip to content

Commit

Permalink
table plans for scattering via fourier slice theorem for now
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Jul 30, 2023
1 parent f59255b commit cd33585
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 51 deletions.
17 changes: 3 additions & 14 deletions notebooks/test-pipeline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"outputs": [],
"source": [
"# Image simulator\n",
"from jax_2dtm.simulator import NufftScattering, FourierSliceScattering\n",
"from jax_2dtm.simulator import NufftScattering\n",
"from jax_2dtm.simulator import ScatteringImage, OpticsImage, GaussianImage\n",
"from jax_2dtm.simulator import AntiAliasingFilter, WhiteningFilter, CircularMask\n",
"from jax_2dtm.simulator import EulerPose, CTFOptics, Intensity, ExponentialNoise, WhiteNoise, ParameterState\n",
Expand All @@ -68,11 +68,8 @@
"outputs": [],
"source": [
"# Configure image data and read template as point cloud\n",
"config = FourierSliceScattering((80, 80), pixel_size, pad_scale=1)\n",
"cloud = load_grid_as_cloud(filename, config, real=False)\n",
"#config = NufftScattering((95, 95), pixel_size, pad_scale=2.2, eps=1e-5)\n",
"#config = GaussianScattering((320, 320), pixel_size, pad_scale=1.4, scale=1/3)\n",
"#cloud = load_grid_as_cloud(filename, config, atol=1e-6)"
"config = NufftScattering((80, 80), pixel_size, pad_scale=2.2, eps=1e-5)\n",
"cloud = load_grid_as_cloud(filename, config, atol=1e-6)"
]
},
{
Expand All @@ -83,7 +80,6 @@
"source": [
"# Initialize model, parameters, and compute image\n",
"pose = EulerPose(-50.0, 30.0, np.pi / 4, np.pi / 10, -np.pi / 6)\n",
"pose = EulerPose()\n",
"optics = CTFOptics()\n",
"noise = ExponentialNoise(key=jax.random.PRNGKey(seed=0), kappa=0.1, xi=6.0, alpha=0.5)\n",
"intensity = Intensity()\n",
Expand Down Expand Up @@ -277,13 +273,6 @@
"# Benchmark gradient\n",
"%timeit grad = grad_loss(params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
7 changes: 2 additions & 5 deletions src/jax_2dtm/io/load_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jax.numpy as jnp
from typing import Any
from ..simulator import Cloud, ImageConfig
from ..utils import fftfreqs, fft, pad
from ..utils import fftfreqs, fft
from ..core import Array, ArrayLike


Expand Down Expand Up @@ -97,7 +97,7 @@ def coordinatize(
Camera pixel size.
real : `bool`
If ``True``, return flattened density and coordinate
system in real space. If ``False``, return structured
system in real space. If ``False``, return density
coordinates in Fourier space.
kwargs
Keyword arguments passed to ``np.isclose``.
Expand All @@ -119,9 +119,6 @@ def coordinatize(
mask = np.where(~np.isclose(flat, 0.0, **kwargs))
density = flat[mask]
else:
if shape != tuple(ndim * [shape[0]]):
template = pad(template, tuple(ndim * [max(shape)]))
shape = template.shape
density = fft(template)
mask = True

Expand Down
17 changes: 6 additions & 11 deletions src/jax_2dtm/simulator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
# Functional API
"rotate_and_translate_rpy",
"rotate_and_translate_wxyz",
"rotate_rpy",
"rotate_wxyz",
"translate",
"project_with_nufft",
"project_with_gaussians",
"project_with_slice",
# "project_with_slice",
"compute_anti_aliasing_filter",
"compute_whitening_filter",
"compute_circular_mask",
Expand All @@ -17,7 +14,7 @@
"ImageConfig",
"NufftScattering",
"GaussianScattering",
"FourierSliceScattering",
# "FourierSliceScattering"
# Point clouds
"Cloud",
# Filters
Expand Down Expand Up @@ -57,9 +54,6 @@
from .pose import (
rotate_and_translate_rpy,
rotate_and_translate_wxyz,
rotate_rpy,
rotate_wxyz,
translate,
EulerPose,
QuaternionPose,
)
Expand All @@ -68,14 +62,15 @@
from .scattering import (
project_with_nufft,
project_with_gaussians,
project_with_slice,
ImageConfig,
NufftScattering,
GaussianScattering,
FourierSliceScattering,
# FourierSliceScattering,
)

ScatteringConfig = Union[NufftScattering, GaussianScattering]
ScatteringConfig = Union[
NufftScattering, GaussianScattering
] # , FourierSliceScattering]
from .cloud import Cloud
from .filters import (
compute_anti_aliasing_filter,
Expand Down
39 changes: 23 additions & 16 deletions src/jax_2dtm/simulator/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
"rotate_and_translate_wxyz",
"rotate_rpy",
"rotate_wxyz",
"translate",
"shift_phase",
"Pose",
"EulerPose",
"QuaternionPose",
]

from abc import ABCMeta, abstractmethod
from functools import partial

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -84,9 +83,16 @@ def transform(
coordinates, *self.iter_data()
)
else:
return translate(
density, coordinates, *self.iter_data()[:2]
), rotate_rpy(coordinates, *self.iter_data()[2:], inverse=True)
raise NotImplementedError
# rotated_coordinates = rotate_rpy(
# coordinates, *self.iter_data()[2:]
# )
# shifted_density = shift_phase(
# density,
# rotated_coordinates,
# *self.iter_data()[:2],
# )
# return shifted_density, rotated_coordinates


@dataclass
Expand Down Expand Up @@ -118,9 +124,14 @@ def transform(
coordinates, *self.iter_data()
)
else:
return translate(
density, coordinates, *self.iter_data()[:2]
), rotate_wxyz(coordinates, *self.iter_data()[2:], inverse=True)
raise NotImplementedError
# rotated_coordinates = rotate_wxyz(
# coordinates, *self.iter_data()[2:]
# )
# shifted_density = shift_phase(
# density, rotated_coordinates, *self.iter_data()[:2]
# )
# return shifted_density, rotated_coordinates


@jax.jit
Expand Down Expand Up @@ -203,13 +214,12 @@ def rotate_and_translate_wxyz(
return transformed


# @partial(jax.jit, static_argnums=-1)
@jax.jit
def rotate_rpy(
coords: Array,
phi: float,
theta: float,
psi: float,
inverse: bool = False,
) -> Array:
r"""
Compute a coordinate rotation from
Expand All @@ -234,20 +244,18 @@ def rotate_rpy(
N1, N2, N3 = coords.shape[:-1]
N = N1 * N2 * N3
rotation = SO3.from_rpy_radians(phi, theta, psi)
rotation = rotation.inverse() if inverse else rotation
transformed = jax.vmap(rotation.apply)(coords.reshape((N, 3)))

return transformed.reshape(coords.shape)


# @partial(jax.jit, static_argnums=-1)
@jax.jit
def rotate_wxyz(
coords: Array,
qw: float,
qx: float,
qy: float,
qz: float,
inverse: bool = False,
) -> Array:
r"""
Compute a coordinate rotation from a quaternion.
Expand All @@ -270,14 +278,13 @@ def rotate_wxyz(
N = N1 * N2 * N3
wxyz = jnp.array([qw, qx, qy, qz])
rotation = SO3.from_quaternion_xyzw(wxyz)
rotation = rotation.inverse() if inverse else rotation
transformed = jax.vmap(rotation.apply)(coords.reshape((N, 3)))

return transformed.reshape(coords.shape)


@jax.jit
def translate(
def shift_phase(
density: Array,
coords: Array,
tx: float,
Expand All @@ -304,7 +311,7 @@ def translate(
Rotated and translated coordinate system.
"""
xyz = jnp.array([tx, ty, 0.0])
shift = jnp.exp(1.0j * jnp.matmul(coords, xyz))
shift = jnp.exp(1.0j * 2 * jnp.pi * jnp.matmul(coords, xyz))
transformed = density * shift

return transformed
13 changes: 8 additions & 5 deletions src/jax_2dtm/simulator/scattering.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,19 @@ class FourierSliceScattering(ScatteringConfig):
Fourier-projection slice theorem.
"""

order: int = field(pytree_node=False, default=1)

def project(self, *args: Any, **kwargs: Any):
"""
Compute image by interpolating onto the
imaging plane.
"""
density, coordinates, _ = args
projection = project_with_slice(
density, coordinates, self.pixel_size, **kwargs
)
return resize(projection, self.padded_shape, antialias=False)
raise NotImplementedError
# density, coordinates, _ = args
# projection = project_with_slice(
# density, coordinates, self.pixel_size, order=self.order, **kwargs
# )
# return resize(projection, self.padded_shape, antialias=False)


@dataclass
Expand Down

0 comments on commit cd33585

Please sign in to comment.