diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 4782ea0e..7c87f58c 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -1,3 +1,4 @@ +from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition from xlb.operator.boundary_condition.boundary_condition_registry import BoundaryConditionRegistry from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 0ff67759..aeefd788 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -16,9 +16,6 @@ ImplementationStep, BoundaryCondition, ) -from xlb.operator.boundary_condition.boundary_condition_registry import ( - boundary_condition_registry, -) class DoNothingBC(BoundaryCondition): diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 260ceb51..85cfd653 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -19,9 +19,6 @@ ImplementationStep, BoundaryCondition, ) -from xlb.operator.boundary_condition.boundary_condition_registry import ( - boundary_condition_registry, -) class EquilibriumBC(BoundaryCondition): diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index fa75490d..b968b6a5 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -19,9 +19,6 @@ ImplementationStep, BoundaryCondition, ) -from xlb.operator.boundary_condition.boundary_condition_registry import ( - boundary_condition_registry, -) class ExtrapolationOutflowBC(BoundaryCondition): diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index bac93878..995e2ff9 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -17,9 +17,6 @@ BoundaryCondition, ImplementationStep, ) -from xlb.operator.boundary_condition.boundary_condition_registry import ( - boundary_condition_registry, -) class FullwayBounceBackBC(BoundaryCondition): diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index 217c65c5..22fbb4ec 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -23,9 +23,6 @@ ImplementationStep, BoundaryCondition, ) -from xlb.operator.boundary_condition.boundary_condition_registry import ( - boundary_condition_registry, -) class GradsApproximationBC(BoundaryCondition): diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index bf04af03..8ede0c8b 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -17,9 +17,6 @@ ImplementationStep, BoundaryCondition, ) -from xlb.operator.boundary_condition.boundary_condition_registry import ( - boundary_condition_registry, -) class HalfwayBounceBackBC(BoundaryCondition): diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 91776e01..1950fc1b 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -14,10 +14,8 @@ from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC -from xlb.operator.boundary_condition.boundary_condition import ImplementationStep -from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry -from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux +from xlb.operator.boundary_condition import ZouHeBC, HelperFunctionsBC +from xlb.operator.macroscopic import SecondMoment as MomentumFlux class RegularizedBC(ZouHeBC): @@ -64,7 +62,6 @@ def __init__( indices, mesh_vertices, ) - # Overwrite the boundary condition registry id with the bc_type in the name self.momentum_flux = MomentumFlux() @partial(jit, static_argnums=(0,), inline=True) @@ -127,83 +124,12 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): return f_post def _construct_warp(self): - # assign placeholders for both u and rho based on prescribed_value + # load helper functions + bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend) + # Set local constants _d = self.velocity_set.d _q = self.velocity_set.q - - # Set local constants TODO: This is a hack and should be fixed with warp update - # _u_vec = wp.vec(_d, dtype=self.compute_dtype) - # compute Qi tensor and store it in self - _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _opp_indices = self.velocity_set.opp_indices - _w = self.velocity_set.w - _c = self.velocity_set.c - _c_float = self.velocity_set.c_float - _qi = self.velocity_set.qi - # TODO: related to _c_float: this is way less than ideal. we should not be making new types - - @wp.func - def _get_fsum( - fpop: Any, - missing_mask: Any, - ): - fsum_known = self.compute_dtype(0.0) - fsum_middle = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[_opp_indices[l]] == wp.uint8(1): - fsum_known += self.compute_dtype(2.0) * fpop[l] - elif missing_mask[l] != wp.uint8(1): - fsum_middle += fpop[l] - return fsum_known + fsum_middle - - @wp.func - def get_normal_vectors( - missing_mask: Any, - ): - if wp.static(_d == 3): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) - else: - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l]) - - @wp.func - def bounceback_nonequilibrium( - fpop: Any, - feq: Any, - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] - return fpop - - @wp.func - def regularize_fpop( - fpop: Any, - feq: Any, - ): - """ - Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. - """ - # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} - f_neq = fpop - feq - PiNeq = self.momentum_flux.warp_functional(f_neq) - - # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) - nt = _d * (_d + 1) // 2 - for l in range(_q): - QiPi1 = self.compute_dtype(0.0) - for t in range(nt): - QiPi1 += _qi[l, t] * PiNeq[t] - - # assign all populations based on eq 45 of Latt et al (2008) - # fneq ~ f^1 - fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1 - fpop[l] = feq[l] + fpop1 - return fpop @wp.func def functional_velocity( @@ -219,7 +145,7 @@ def functional_velocity( _f = f_post # Find normal vector - normals = get_normal_vectors(missing_mask) + normals = bc_helper.get_normal_vectors(missing_mask) # Find the value of u from the missing directions # Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1) @@ -228,7 +154,7 @@ def functional_velocity( _u = -prescribed_value * normals # calculate rho - fsum = _get_fsum(_f, missing_mask) + fsum = bc_helper.get_bc_fsum(_f, missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -236,10 +162,10 @@ def functional_velocity( # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) + _f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask) # Regularize the boundary fpop - _f = regularize_fpop(_f, feq) + _f = bc_helper.regularize_fpop(_f, feq) return _f @wp.func @@ -256,23 +182,23 @@ def functional_pressure( _f = f_post # Find normal vector - normals = get_normal_vectors(missing_mask) + normals = bc_helper.get_normal_vectors(missing_mask) # Find the value of rho from the missing directions # Since we need only one scalar value, we only need to find one value (stored at the center of f_1) _rho = f_1[0, index[0], index[1], index[2]] # calculate velocity - fsum = _get_fsum(_f, missing_mask) + fsum = bc_helper.get_bc_fsum(_f, missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, missing_mask) + _f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask) # Regularize the boundary fpop - _f = regularize_fpop(_f, feq) + _f = bc_helper.regularize_fpop(_f, feq) return _f if self.bc_type == "velocity": diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index cfb1a34e..5cad5048 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -18,11 +18,8 @@ ImplementationStep, BoundaryCondition, ) -from xlb.operator.boundary_condition.boundary_condition_registry import ( - boundary_condition_registry, -) +from xlb.operator.boundary_condition import HelperFunctionsBC from xlb.operator.equilibrium import QuadraticEquilibrium -import jax class ZouHeBC(BoundaryCondition): @@ -272,55 +269,12 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): return f_post def _construct_warp(self): - # assign placeholders for both u and rho based on prescribed_value + # load helper functions + bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend) + # Set local constants _d = self.velocity_set.d _q = self.velocity_set.q - - # Set local constants TODO: This is a hack and should be fixed with warp update - # _u_vec = wp.vec(_d, dtype=self.compute_dtype) - _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _opp_indices = self.velocity_set.opp_indices - _c = self.velocity_set.c - _c_float = self.velocity_set.c_float - # TODO: this is way less than ideal. we should not be making new types - - @wp.func - def _get_fsum( - fpop: Any, - missing_mask: Any, - ): - fsum_known = self.compute_dtype(0.0) - fsum_middle = self.compute_dtype(0.0) - for l in range(_q): - if missing_mask[_opp_indices[l]] == wp.uint8(1): - fsum_known += self.compute_dtype(2.0) * fpop[l] - elif missing_mask[l] != wp.uint8(1): - fsum_middle += fpop[l] - return fsum_known + fsum_middle - - @wp.func - def get_normal_vectors( - missing_mask: Any, - ): - if wp.static(_d == 3): - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) - else: - for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c_float[0, l], _c_float[1, l]) - - @wp.func - def bounceback_nonequilibrium( - fpop: Any, - feq: Any, - missing_mask: Any, - ): - for l in range(_q): - if missing_mask[l] == wp.uint8(1): - fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] - return fpop @wp.func def functional_velocity( @@ -336,10 +290,10 @@ def functional_velocity( _f = _f_post # Find normal vector - normals = get_normal_vectors(_missing_mask) + normals = bc_helper.get_normal_vectors(_missing_mask) # calculate rho - fsum = _get_fsum(_f, _missing_mask) + fsum = bc_helper.get_bc_fsum(_f, _missing_mask) unormal = self.compute_dtype(0.0) # Find the value of u from the missing directions @@ -355,7 +309,7 @@ def functional_velocity( # impose non-equilibrium bounceback _feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, _feq, _missing_mask) + _f = bc_helper.bounceback_nonequilibrium(_f, _feq, _missing_mask) return _f @wp.func @@ -372,20 +326,20 @@ def functional_pressure( _f = _f_post # Find normal vector - normals = get_normal_vectors(_missing_mask) + normals = bc_helper.get_normal_vectors(_missing_mask) # Find the value of rho from the missing directions # Since we need only one scalar value, we only need to find one value (stored at the center of f_1) _rho = f_1[0, index[0], index[1], index[2]] # calculate velocity - fsum = _get_fsum(_f, _missing_mask) + fsum = bc_helper.get_bc_fsum(_f, _missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bounceback_nonequilibrium(_f, feq, _missing_mask) + _f = bc_helper.bounceback_nonequilibrium(_f, feq, _missing_mask) return _f if self.bc_type == "velocity": diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 323c75dc..0b8a93c6 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -16,6 +16,7 @@ from xlb.operator.operator import Operator from xlb import DefaultConfig from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.boundary_condition import HelperFunctionsBC # Enum for implementation step @@ -71,53 +72,6 @@ def __init__( # A flag for BCs that need auxilary data recovery after streaming self.needs_aux_recovery = False - if self.compute_backend == ComputeBackend.WARP: - # Set local constants TODO: This is a hack and should be fixed with warp update - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - - @wp.func - def update_bc_auxilary_data( - index: Any, - timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - return f_post - - @wp.func - def _get_thread_data( - f_pre: wp.array4d(dtype=Any), - f_post: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - index: wp.vec3i, - ): - # Get the boundary id and missing mask - _f_pre = _f_vec() - _f_post = _f_vec() - _boundary_id = bc_mask[0, index[0], index[1], index[2]] - _missing_mask = _missing_mask_vec() - for l in range(self.velocity_set.q): - # q-sized vector of populations - _f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1], index[2]]) - _f_post[l] = self.compute_dtype(f_post[l, index[0], index[1], index[2]]) - - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) - return _f_pre, _f_post, _boundary_id, _missing_mask - - # Construct some helper warp functions for getting tid data - if self.compute_backend == ComputeBackend.WARP: - self._get_thread_data = _get_thread_data - self.update_bc_auxilary_data = update_bc_auxilary_data - @partial(jit, static_argnums=(0,), inline=True) def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): """ @@ -131,6 +85,7 @@ def _construct_kernel(self, functional): Constructs the warp kernel for the boundary condition. The functional is specific to each boundary condition and should be passed as an argument. """ + bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend) _id = wp.uint8(self.id) # Construct the warp kernel @@ -146,7 +101,7 @@ def kernel( index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data(f_pre, f_post, bc_mask, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = bc_helper.get_thread_data(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == _id: @@ -165,6 +120,8 @@ def _construct_aux_data_init_kernel(self, functional): """ Constructs the warp kernel for the auxilary data recovery. """ + bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend) + _id = wp.uint8(self.id) _opp_indices = self.velocity_set.opp_indices _num_of_aux_data = self.num_of_aux_data @@ -182,7 +139,7 @@ def aux_data_init_kernel( index = wp.vec3i(i, j, k) # read tid data - _f_0, _f_1, _boundary_id, _missing_mask = self._get_thread_data(f_0, f_1, bc_mask, missing_mask, index) + _f_0, _f_1, _boundary_id, _missing_mask = bc_helper.get_thread_data(f_0, f_1, bc_mask, missing_mask, index) # Apply the functional if _boundary_id == _id: diff --git a/xlb/operator/boundary_condition/helper_functions_bc.py b/xlb/operator/boundary_condition/helper_functions_bc.py new file mode 100644 index 00000000..6f8e768b --- /dev/null +++ b/xlb/operator/boundary_condition/helper_functions_bc.py @@ -0,0 +1,128 @@ +from xlb import DefaultConfig, ComputeBackend +from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux +import warp as wp +from typing import Any + + +class HelperFunctionsBC(object): + def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None): + if compute_backend == ComputeBackend.JAX: + raise ValueError("This helper class contains helper functions only for the WARP implementation of some BCs not JAX!") + + # Set the default values from the global config + self.velocity_set = velocity_set or DefaultConfig.velocity_set + self.precision_policy = precision_policy or DefaultConfig.default_precision_policy + self.compute_backend = compute_backend or DefaultConfig.default_backend + + # Set the compute and Store dtypes + compute_dtype = self.precision_policy.compute_precision.wp_dtype + store_dtype = self.precision_policy.store_precision.wp_dtype + + # Set local constants + _d = self.velocity_set.d + _q = self.velocity_set.q + _opp_indices = self.velocity_set.opp_indices + _w = self.velocity_set.w + _c = self.velocity_set.c + _c_float = self.velocity_set.c_float + _qi = self.velocity_set.qi + _u_vec = wp.vec(_d, dtype=compute_dtype) + _f_vec = wp.vec(_q, dtype=compute_dtype) + _missing_mask_vec = wp.vec(_q, dtype=wp.uint8) # TODO fix vec bool + + # Define the operator needed for computing the momentum flux + momentum_flux = MomentumFlux(velocity_set, precision_policy, compute_backend) + + @wp.func + def get_thread_data( + f_pre: wp.array4d(dtype=Any), + f_post: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + index: wp.vec3i, + ): + # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() + _boundary_id = bc_mask[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(_q): + # q-sized vector of populations + _f_pre[l] = compute_dtype(f_pre[l, index[0], index[1], index[2]]) + _f_post[l] = compute_dtype(f_post[l, index[0], index[1], index[2]]) + + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + return _f_pre, _f_post, _boundary_id, _missing_mask + + @wp.func + def get_bc_fsum( + fpop: Any, + missing_mask: Any, + ): + fsum_known = compute_dtype(0.0) + fsum_middle = compute_dtype(0.0) + for l in range(_q): + if missing_mask[_opp_indices[l]] == wp.uint8(1): + fsum_known += compute_dtype(2.0) * fpop[l] + elif missing_mask[l] != wp.uint8(1): + fsum_middle += fpop[l] + return fsum_known + fsum_middle + + @wp.func + def get_normal_vectors( + missing_mask: Any, + ): + if wp.static(_d == 3): + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) + else: + for l in range(_q): + if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + return -_u_vec(_c_float[0, l], _c_float[1, l]) + + @wp.func + def bounceback_nonequilibrium( + fpop: Any, + feq: Any, + missing_mask: Any, + ): + for l in range(_q): + if missing_mask[l] == wp.uint8(1): + fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] + return fpop + + @wp.func + def regularize_fpop( + fpop: Any, + feq: Any, + ): + """ + Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. + """ + # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} + f_neq = fpop - feq + PiNeq = momentum_flux.warp_functional(f_neq) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + for l in range(_q): + QiPi1 = compute_dtype(0.0) + for t in range(nt): + QiPi1 += _qi[l, t] * PiNeq[t] + + # assign all populations based on eq 45 of Latt et al (2008) + # fneq ~ f^1 + fpop1 = compute_dtype(4.5) * _w[l] * QiPi1 + fpop[l] = feq[l] + fpop1 + return fpop + + self.get_thread_data = get_thread_data + self.get_bc_fsum = get_bc_fsum + self.get_normal_vectors = get_normal_vectors + self.bounceback_nonequilibrium = bounceback_nonequilibrium + self.regularize_fpop = regularize_fpop