Skip to content

Commit

Permalink
Merge pull request #3 from bunkerj/main
Browse files Browse the repository at this point in the history
Fix get_grid()
  • Loading branch information
astanziola authored May 25, 2024
2 parents f060d2d + 2eda951 commit da1a778
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion fno/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
def get_grid(x):
x1 = jnp.linspace(0, 1, x.shape[1])
x2 = jnp.linspace(0, 1, x.shape[2])
jnp.meshgrid(x1, x2, indexing = 'ij')
x1, x2 = jnp.meshgrid(x1, x2, indexing = 'ij')
grid = jnp.expand_dims(jnp.stack([x1, x2], axis=-1), 0)
batched_grid = jnp.repeat(grid, x.shape[0], axis=0)
return batched_grid

0 comments on commit da1a778

Please sign in to comment.