diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index b968b6a5..884e691e 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -53,9 +53,6 @@ def __init__( mesh_vertices, ) - # Set the flag for auxilary data recovery - self.needs_aux_recovery = True - # find and store the normal vector using indices self._get_normal_vec(indices) @@ -159,15 +156,15 @@ def functional( missing_mask: Any, f_0: Any, f_1: Any, - f_pre: Any, - f_post: Any, + _f_pre: Any, + _f_post: Any, ): # Post-streaming values are only modified at missing direction - _f = f_post + _f = _f_post for l in range(self.velocity_set.q): # If the mask is missing then take the opposite index if missing_mask[l] == wp.uint8(1): - _f[l] = f_pre[_opp_indices[l]] + _f[l] = _f_pre[_opp_indices[l]] return _f @wp.func @@ -177,13 +174,13 @@ def update_bc_auxilary_data( missing_mask: Any, f_0: Any, f_1: Any, - f_pre: Any, - f_post: Any, + _f_pre: Any, + _f_post: Any, ): # Update the auxilary data for this BC using the neighbour's populations stored in f_aux and # f_pre (post-streaming values of the current voxel). We use directions that leave the domain # for storing this prepared data. - _f = f_post + _f = _f_post nv = get_normal_vectors(missing_mask) for l in range(self.velocity_set.q): if missing_mask[l] == wp.uint8(1): @@ -194,7 +191,7 @@ def update_bc_auxilary_data( pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) - _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux + _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * _f_pre[l] + sound_speed * f_aux return _f kernel = self._construct_kernel(functional) @@ -202,11 +199,11 @@ def update_bc_auxilary_data( return (functional, update_bc_auxilary_data), kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + def warp_implementation(self, _f_pre, _f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], + inputs=[_f_pre, _f_post, bc_mask, missing_mask], + dim=_f_pre.shape[1:], ) - return f_post + return _f_post diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index dbd53079..1c6255d3 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -50,12 +50,14 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) - def jax_implementation(self, f, bc_mask, missing_mask): + def jax_implementation(self, f_0, f_1, bc_mask, missing_mask): """ Parameters ---------- - f : jax.numpy.ndarray + f_0 : jax.numpy.ndarray The post-collision distribution function at each node in the grid. + f_1 : jax.numpy.ndarray + The buffer field the same size as f_0 (only given as input for consistency with the WARP backened API.) bc_mask : jax.numpy.ndarray A grid field with 0 everywhere except for boundary nodes which are designated by their respective boundary id's. @@ -69,7 +71,7 @@ def jax_implementation(self, f, bc_mask, missing_mask): The force exerted on the solid geometry at each boundary node. """ # Give the input post-collision populations, streaming once and apply the BC the find post-stream values. - f_post_collision = f + f_post_collision = f_0 f_post_stream = self.stream(f_post_collision) f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_mask, missing_mask) @@ -79,11 +81,13 @@ def jax_implementation(self, f, bc_mask, missing_mask): boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) # the following will return force as a grid-based field with zero everywhere except for boundary nodes. + is_edge = jnp.logical_and(boundary, ~missing_mask[0]) opp = self.velocity_set.opp_indices phi = f_post_collision[opp] + f_post_stream - phi = jnp.where(jnp.logical_and(boundary, missing_mask), phi, 0.0) + phi = jnp.where(jnp.logical_and(missing_mask, is_edge), phi, 0.0) force = jnp.tensordot(self.velocity_set.c[:, opp], phi, axes=(-1, 0)) - return force + force_net = jnp.sum(force, axis=(i + 1 for i in range(self.velocity_set.d))) + return force_net def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py index 4a48c2f0..8d693d7e 100644 --- a/xlb/velocity_set/d3q19.py +++ b/xlb/velocity_set/d3q19.py @@ -16,7 +16,7 @@ class D3Q19(VelocitySet): def __init__(self, precision_policy, backend): # Construct the velocity vectors and weights - c = np.array([ci for ci in itertools.product([-1, 0, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T + c = np.array([ci for ci in itertools.product([0, -1, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T w = np.zeros(19) for i in range(19): if np.sum(np.abs(c[:, i])) == 0: