Skip to content

Commit

Permalink
Used center of f_1 as an additional storage and also fixed some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Jan 2, 2025
1 parent c48e0ec commit 23d06d8
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 55 deletions.
4 changes: 2 additions & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def bc_profile(self):
@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = self.precision_policy.store_precision.wp_dtype(index[1])
z = self.precision_policy.store_precision.wp_dtype(index[2])
y = wp.float32(index[1])
z = wp.float32(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
Expand Down
18 changes: 6 additions & 12 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,10 @@ def functional_velocity(
normals = get_normal_vectors(missing_mask)

# Find the value of u from the missing directions
for l in range(_q):
# Since we are only considering normal velocity, we only need to find one value
if missing_mask[l] == wp.uint8(1):
# Create velocity vector by multiplying the prescribed value with the normal vector
prescribed_value = f_1[_opp_indices[l], index[0], index[1], index[2]]
_u = -prescribed_value * normals
break
# Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1)
# Create velocity vector by multiplying the prescribed value with the normal vector
prescribed_value = f_1[0, index[0], index[1], index[2]]
_u = -prescribed_value * normals

# calculate rho
fsum = _get_fsum(_f, missing_mask)
Expand Down Expand Up @@ -262,11 +259,8 @@ def functional_pressure(
normals = get_normal_vectors(missing_mask)

# Find the value of rho from the missing directions
for q in range(_q):
# Since we need only one scalar value, we only need to find one value
if missing_mask[q] == wp.uint8(1):
_rho = f_1[_opp_indices[q], index[0], index[1], index[2]]
break
# Since we need only one scalar value, we only need to find one value (stored at the center of f_1)
_rho = f_1[0, index[0], index[1], index[2]]

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
Expand Down
62 changes: 25 additions & 37 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def __init__(
if non_zero_count > 1:
raise ValueError("This BC only supports normal prescribed values (only one non-zero element allowed)")

# Prescribed value for this BC must be:
# a single non-zero number associated with normal velocity magnitude for velocity BC OR
# a single non-zero number associated with pressure BC OR
# a vector of zeros associated with no-slip BC.
# Accounting for all scenarios here.
if self.compute_backend is ComputeBackend.WARP:
idx = np.nonzero(prescribed_value)[0]
prescribed_value = prescribed_value[idx][0] if idx.size else 0.0
prescribed_value = self.precision_policy.store_precision.wp_dtype(prescribed_value)
self.prescribed_value = prescribed_value
self.profile = self._create_constant_prescribed_profile()

Expand All @@ -107,28 +116,14 @@ def __init__(
self.needs_padding = True

def _create_constant_prescribed_profile(self):
if self.bc_type == "velocity":

@wp.func
def prescribed_profile_warp(index: wp.vec3i):
# Get the non-zero value from prescribed_value
value = wp.static(
self.precision_policy.store_precision.wp_dtype(float(self.prescribed_value[np.nonzero(self.prescribed_value)[0][0]]))
)
return wp.vec(value, length=1)

def prescribed_profile_jax():
return jnp.array(self.prescribed_value, dtype=self.precision_policy.store_precision.jax_dtype).reshape(-1, 1)
_prescribed_value = self.prescribed_value

else: # pressure

@wp.func
def prescribed_profile_warp(index: wp.vec3i):
value = wp.static(self.precision_policy.store_precision.wp_dtype(self.prescribed_value))
return wp.vec(value, length=1)
@wp.func
def prescribed_profile_warp(index: wp.vec3i):
return wp.vec(_prescribed_value, length=1)

def prescribed_profile_jax():
return jnp.array(self.prescribed_value)
def prescribed_profile_jax():
return jnp.array(_prescribed_value, dtype=self.precision_policy.store_precision.jax_dtype).reshape(-1, 1)

if self.compute_backend == ComputeBackend.JAX:
return prescribed_profile_jax
Expand Down Expand Up @@ -332,8 +327,8 @@ def functional_velocity(
index: Any,
timestep: Any,
_missing_mask: Any,
f_pre: Any,
f_post: Any,
f_0: Any,
f_1: Any,
_f_pre: Any,
_f_post: Any,
):
Expand All @@ -348,14 +343,10 @@ def functional_velocity(
unormal = self.compute_dtype(0.0)

# Find the value of u from the missing directions
for l in range(_q):
# Since we are only considering normal velocity, we only need to find one value (all values are the same in the missing directions)
if _missing_mask[l] == wp.uint8(1):
# Create velocity vector by multiplying the prescribed value with the normal vector
# TODO: This can be optimized by saving _missing_mask[l] in the bc class later since it is the same for all boundary cells
prescribed_value = f_post[_opp_indices[l], index[0], index[1], index[2]]
_u = -prescribed_value * normals
break
# Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1)
# Create velocity vector by multiplying the prescribed value with the normal vector
prescribed_value = f_1[0, index[0], index[1], index[2]]
_u = -prescribed_value * normals

for d in range(_d):
unormal += _u[d] * normals[d]
Expand All @@ -372,8 +363,8 @@ def functional_pressure(
index: Any,
timestep: Any,
_missing_mask: Any,
f_pre: Any,
f_post: Any,
f_0: Any,
f_1: Any,
_f_pre: Any,
_f_post: Any,
):
Expand All @@ -384,11 +375,8 @@ def functional_pressure(
normals = get_normal_vectors(_missing_mask)

# Find the value of rho from the missing directions
for q in range(_q):
# Since we need only one scalar value, we only need to find one value (all values are the same in the missing directions)
if _missing_mask[q] == wp.uint8(1):
_rho = f_post[_opp_indices[q], index[0], index[1], index[2]]
break
# Since we need only one scalar value, we only need to find one value (stored at the center of f_1)
_rho = f_1[0, index[0], index[1], index[2]]

# calculate velocity
fsum = _get_fsum(_f, _missing_mask)
Expand Down
5 changes: 3 additions & 2 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ def aux_data_init_kernel(
prescribed_values = functional(index)
# Write the result for all q directions, but only store up to num_of_aux_data
# TODO: Somehow raise an error if the number of prescribed values does not match the number of missing directions
counter = wp.int32(0)
for l in range(self.velocity_set.q):
f_1[0, index[0], index[1], index[2]] = self.store_dtype(prescribed_values[0])
counter = wp.int32(1)
for l in range(1, self.velocity_set.q):
if _missing_mask[l] == wp.uint8(1) and counter < _num_of_aux_data:
f_1[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(prescribed_values[counter])
counter += 1
Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/equilibrium/quadratic_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class QuadraticEquilibrium(Equilibrium):
def jax_implementation(self, rho, u):
cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0))
usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True)
w = self.velocity_set.w.reshape((-1,) + (1,) * (len(rho.shape) - 1))
w = self.velocity_set.w.reshape((-1,) + (1,) * self.velocity_set.d)
feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr)
return feq

Expand Down
3 changes: 2 additions & 1 deletion xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def apply_aux_recovery_bc(
if wp.static(self.boundary_conditions[i].needs_aux_recovery):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
# Perform the swapping of data
for l in range(self.velocity_set.q):
f_0[0, index[0], index[1], index[2]] = self.store_dtype(_f1_thread[0])
for l in range(1, self.velocity_set.q):
if _missing_mask[l] == wp.uint8(1):
f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]])

Expand Down

0 comments on commit 23d06d8

Please sign in to comment.