Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the momentum exchange method in the JAX backend. #101

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ def __init__(
mesh_vertices,
)

# Set the flag for auxilary data recovery
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this removed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was not needed for ExtrapolationOutflow BC. that BC directly uses _f_post_collision to put all the necessary ingredients.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure. We created reconstruction originally for this BC!:

        # Update the auxilary data for this BC using the neighbour's populations stored in f_aux and
        # f_pre (post-streaming values of the current voxel). We use directions that leave the domain
        # for storing this prepared data.

If you don[t recover, you'll 100% overwrite this data...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed those comments since then. I can remove this part and add it to another PR if that helps you understand why but this BC assembles all the information and stores it directly in _f_post_collision which is then recovered when we write in f_1 (not f_0 as in bc_recovery). Just like how halfway bc uses post-collision values in the opposite direction taken from f_0, this BC also takes "post-collision" values taken from f_0 which are hand crafted earlier inside a function that is called "update_bc_auxilary_data".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to think about this more to make sure it is correct. We're rushing PR merges for no reason and we're reverting them back or fix them almost immedately a few days later. I'll check this next week.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not reverting anything! There are minor things that have been missed in the reivews or unknown new bugs that are discovered.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Hesam. I took a look at this. I agree that it is not needed. Pls make this change only now that you're at it:
Pls use underline for _f_pre... for the register values in bc_extrapolation_outflow.py

Thanks

self.needs_aux_recovery = True

# find and store the normal vector using indices
self._get_normal_vec(indices)

Expand Down Expand Up @@ -159,15 +156,15 @@ def functional(
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
_f_pre: Any,
_f_post: Any,
):
# Post-streaming values are only modified at missing direction
_f = f_post
_f = _f_post
for l in range(self.velocity_set.q):
# If the mask is missing then take the opposite index
if missing_mask[l] == wp.uint8(1):
_f[l] = f_pre[_opp_indices[l]]
_f[l] = _f_pre[_opp_indices[l]]
return _f

@wp.func
Expand All @@ -177,13 +174,13 @@ def update_bc_auxilary_data(
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
_f_pre: Any,
_f_post: Any,
):
# Update the auxilary data for this BC using the neighbour's populations stored in f_aux and
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
# for storing this prepared data.
_f = f_post
_f = _f_post
nv = get_normal_vectors(missing_mask)
for l in range(self.velocity_set.q):
if missing_mask[l] == wp.uint8(1):
Expand All @@ -194,19 +191,19 @@ def update_bc_auxilary_data(
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]])
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * _f_pre[l] + sound_speed * f_aux
return _f

kernel = self._construct_kernel(functional)

return (functional, update_bc_auxilary_data), kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
def warp_implementation(self, _f_pre, _f_post, bc_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, bc_mask, missing_mask],
dim=f_pre.shape[1:],
inputs=[_f_pre, _f_post, bc_mask, missing_mask],
dim=_f_pre.shape[1:],
)
return f_post
return _f_post
14 changes: 9 additions & 5 deletions xlb/operator/force/momentum_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ def __init__(

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0))
def jax_implementation(self, f, bc_mask, missing_mask):
def jax_implementation(self, f_0, f_1, bc_mask, missing_mask):
"""
Parameters
----------
f : jax.numpy.ndarray
f_0 : jax.numpy.ndarray
The post-collision distribution function at each node in the grid.
f_1 : jax.numpy.ndarray
The buffer field the same size as f_0 (only given as input for consistency with the WARP backened API.)
bc_mask : jax.numpy.ndarray
A grid field with 0 everywhere except for boundary nodes which are designated
by their respective boundary id's.
Expand All @@ -69,7 +71,7 @@ def jax_implementation(self, f, bc_mask, missing_mask):
The force exerted on the solid geometry at each boundary node.
"""
# Give the input post-collision populations, streaming once and apply the BC the find post-stream values.
f_post_collision = f
f_post_collision = f_0
f_post_stream = self.stream(f_post_collision)
f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_mask, missing_mask)

Expand All @@ -79,11 +81,13 @@ def jax_implementation(self, f, bc_mask, missing_mask):
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))

# the following will return force as a grid-based field with zero everywhere except for boundary nodes.
is_edge = jnp.logical_and(boundary, ~missing_mask[0])
opp = self.velocity_set.opp_indices
phi = f_post_collision[opp] + f_post_stream
phi = jnp.where(jnp.logical_and(boundary, missing_mask), phi, 0.0)
phi = jnp.where(jnp.logical_and(missing_mask, is_edge), phi, 0.0)
force = jnp.tensordot(self.velocity_set.c[:, opp], phi, axes=(-1, 0))
return force
force_net = jnp.sum(force, axis=(i + 1 for i in range(self.velocity_set.d)))
return force_net

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
Expand Down
2 changes: 1 addition & 1 deletion xlb/velocity_set/d3q19.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class D3Q19(VelocitySet):

def __init__(self, precision_policy, backend):
# Construct the velocity vectors and weights
c = np.array([ci for ci in itertools.product([-1, 0, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to make this consistant, pls make sure that D2Q9 is similar. But overall, again this could change in the future, as you have store them symmetrically and avoid storing values for opp index.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is consistent for all 3 lattices.

c = np.array([ci for ci in itertools.product([0, -1, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T
w = np.zeros(19)
for i in range(19):
if np.sum(np.abs(c[:, i])) == 0:
Expand Down
Loading