Skip to content

Commit

Permalink
Merge pull request #195 from mjo22/gradient-testing
Browse files Browse the repository at this point in the history
Fix bug that stops gradients with respect to rotations and pixel size rescalings. Closes #194.
  • Loading branch information
mjo22 authored Apr 3, 2024
2 parents 3b14a41 + 6fa9044 commit 35f906e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
7 changes: 3 additions & 4 deletions src/cryojax/coordinates/_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing_extensions import Self

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from equinox import AbstractVar
Expand Down Expand Up @@ -91,7 +90,7 @@ def __init__(
self.array = make_coordinates(shape, grid_spacing)

def get(self) -> ImageCoords | VolumeCoords:
return jax.lax.stop_gradient(self.array)
return self.array


class FrequencyGrid(AbstractCoordinates, strict=True):
Expand All @@ -110,7 +109,7 @@ def __init__(
self.array = make_frequencies(shape, grid_spacing, half_space=half_space)

def get(self) -> ImageCoords | VolumeCoords:
return jax.lax.stop_gradient(self.array)
return self.array


class FrequencySlice(AbstractCoordinates, strict=True):
Expand Down Expand Up @@ -146,7 +145,7 @@ def __init__(
self.array = frequency_slice

def get(self) -> VolumeSliceCoords:
return jax.lax.stop_gradient(self.array)
return self.array


def make_coordinates(
Expand Down
4 changes: 2 additions & 2 deletions src/cryojax/image/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import jax.numpy as jnp

from ..typing import ComplexImage, Image, RealNumber
from ..typing import ComplexImage, ComplexNumber, Image, RealNumber


def normalize_image(
Expand Down Expand Up @@ -91,7 +91,7 @@ def compute_mean_and_std_from_fourier_image(
*,
half_space: bool = True,
shape_in_real_space: Optional[tuple[int, int]] = None,
) -> tuple[RealNumber, RealNumber]:
) -> tuple[ComplexNumber, RealNumber]:
"""Compute the mean and standard deviation in real space from
an image in fourier space.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/cryojax/simulator/_integrators/_fourier_slice_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def extract_slice(


def extract_slice_with_cubic_spline(
spline_coefficients: Complex[Array, "N+2 N+2 N+2"],
frequency_slice: Float[Array, "1 N N 3"],
spline_coefficients: Complex[Array, "N N N"],
frequency_slice: Float[Array, "1 N-2 N-2 3"],
**kwargs: Any,
) -> Complex[Array, "N N//2+1"]:
) -> Complex[Array, " N-2 (N-2)//2+1"]:
"""
Project and interpolate 3D volume point cloud
onto imaging plane using the fourier slice theorem, using cubic
Expand Down

0 comments on commit 35f906e

Please sign in to comment.