Skip to content

Commit

Permalink
Merge branch 'main' into voltage-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 authored Apr 7, 2024
2 parents c17b3bb + 35f906e commit 6900b03
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/cryojax/coordinates/_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing_extensions import Self

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from equinox import AbstractVar
Expand Down Expand Up @@ -90,7 +89,7 @@ def __init__(
def get(
self,
) -> Float[Array, "y_dim x_dim 2"] | Float[Array, "z_dim y_dim x_dim 3"]:
return jax.lax.stop_gradient(self.array)
return self.array


class FrequencyGrid(AbstractCoordinates, strict=True):
Expand All @@ -113,7 +112,7 @@ def __init__(
def get(
self,
) -> Float[Array, "y_dim x_dim 2"] | Float[Array, "z_dim y_dim x_dim 3"]:
return jax.lax.stop_gradient(self.array)
return self.array


class FrequencySlice(AbstractCoordinates, strict=True):
Expand Down Expand Up @@ -149,7 +148,7 @@ def __init__(
self.array = frequency_slice

def get(self) -> Float[Array, "1 y_dim x_dim 3"]:
return jax.lax.stop_gradient(self.array)
return self.array


def make_coordinates(
Expand Down

0 comments on commit 6900b03

Please sign in to comment.