From 76597d4f5cf8f88b69272e62cdb7cf9e43c62969 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 20 Aug 2024 08:14:29 -0400 Subject: [PATCH] used lax.broadcast_in_dim instead of jnp.repeat plus other minor changes --- .../boundary_condition/bc_fullway_bounce_back.py | 4 +++- .../boundary_condition/bc_halfway_bounce_back.py | 4 +++- xlb/operator/boundary_condition/bc_regularized.py | 14 ++++++++------ xlb/operator/boundary_condition/bc_zouhe.py | 14 ++++++++------ xlb/operator/macroscopic/__init__.py | 2 +- xlb/operator/macroscopic/zero_first_moments.py | 2 +- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 85a3c356..0083bae7 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -47,7 +48,8 @@ def __init__( @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) def _construct_warp(self): diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index df947c67..2ed00677 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -50,7 +51,8 @@ def __init__( @partial(jit, static_argnums=(0)) def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) return jnp.where( jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 3d38f896..413a37f6 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -139,7 +140,8 @@ def regularize_fpop(self, fpop, feq): def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): # creat a mask to slice boundary cells boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) # compute the equilibrium based on prescribed values and the type of BC feq = self.calculate_equilibrium(f_post, missing_mask) @@ -185,7 +187,7 @@ def get_normal_vectors_2d( return normals @wp.func - def _helper_function( + def _get_fsum( fpop: Any, missing_mask: Any, ): @@ -256,7 +258,7 @@ def functional3d_velocity( normals = get_normal_vectors_3d(missing_mask) # calculate rho - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -283,7 +285,7 @@ def functional3d_pressure( normals = get_normal_vectors_3d(missing_mask) # calculate velocity - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = -1.0 + fsum / _rho _u = unormal * normals @@ -308,7 +310,7 @@ def functional2d_velocity( normals = get_normal_vectors_2d(missing_mask) # calculate rho - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -335,7 +337,7 @@ def functional2d_pressure( normals = get_normal_vectors_2d(missing_mask) # calculate velocity - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = -1.0 + fsum / _rho _u = unormal * normals diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index fbb463b3..3b69b21a 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from jax import jit +import jax.lax as lax from functools import partial import warp as wp from typing import Any @@ -156,7 +157,8 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask): def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): # creat a mask to slice boundary cells boundary = boundary_mask == self.id - boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) + new_shape = (self.velocity_set.q,) + boundary.shape[1:] + boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1))) # compute the equilibrium based on prescribed values and the type of BC feq = self.calculate_equilibrium(f_post, missing_mask) @@ -193,7 +195,7 @@ def get_normal_vectors_2d( return normals @wp.func - def _helper_function( + def _get_fsum( fpop: Any, missing_mask: Any, ): @@ -238,7 +240,7 @@ def functional3d_velocity( normals = get_normal_vectors_3d(missing_mask) # calculate rho - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -262,7 +264,7 @@ def functional3d_pressure( normals = get_normal_vectors_3d(missing_mask) # calculate velocity - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = -1.0 + fsum / _rho _u = unormal * normals @@ -284,7 +286,7 @@ def functional2d_velocity( normals = get_normal_vectors_2d(missing_mask) # calculate rho - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -308,7 +310,7 @@ def functional2d_pressure( normals = get_normal_vectors_2d(missing_mask) # calculate velocity - fsum = _helper_function(_f, missing_mask) + fsum = _get_fsum(_f, missing_mask) unormal = -1.0 + fsum / _rho _u = unormal * normals diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 42b747f3..38195cd4 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.macroscopic.zero_first_moments import FirstAndZerothMoment as Macroscopic +from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment diff --git a/xlb/operator/macroscopic/zero_first_moments.py b/xlb/operator/macroscopic/zero_first_moments.py index 7e64fc3f..fbf7c939 100644 --- a/xlb/operator/macroscopic/zero_first_moments.py +++ b/xlb/operator/macroscopic/zero_first_moments.py @@ -10,7 +10,7 @@ from xlb.operator.operator import Operator -class FirstAndZerothMoment(Operator): +class ZeroAndFirstMoments(Operator): """ A class to compute first and zeroth moments of distribution functions.