Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved omega from an attribute of the collision to the input of its callable #100

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
self.backend = backend
self.precision_policy = precision_policy
self.omega = omega

self.boundary_conditions = []
self.u_max = 0.04

Expand Down Expand Up @@ -75,7 +76,6 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
Expand Down Expand Up @@ -127,7 +127,7 @@ def bc_profile_jax():
def run(self, num_steps, post_process_interval=100):
start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,14 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
Expand Down
1 change: 0 additions & 1 deletion examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre
def setup_stepper(self):
# Create the base stepper
stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def initialize_fields(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="KBC",
Expand All @@ -108,7 +107,7 @@ def setup_stepper(self):
def run(self, num_steps, print_interval, post_process_interval=100):
start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if (i + 1) % print_interval == 0:
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="KBC",
Expand All @@ -111,7 +110,7 @@ def run(self, num_steps, print_interval, post_process_interval=100):

start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if (i + 1) % print_interval == 0:
Expand Down
5 changes: 3 additions & 2 deletions examples/performance/mlups_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run(backend, precision_policy, grid_shape, num_steps):
boundary_conditions = [EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid), FullwayBounceBackBC(indices=walls)]

# Create stepper
stepper = IncompressibleNavierStokesStepper(omega=1.0, grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK")
stepper = IncompressibleNavierStokesStepper(grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK")

# Distribute if using JAX backend
if backend == ComputeBackend.JAX:
Expand All @@ -64,11 +64,12 @@ def run(backend, precision_policy, grid_shape, num_steps):
)

# Initialize fields and run simulation
omega = 1.0
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()
start_time = time.time()

for i in range(num_steps):
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i)
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i)
f_0, f_1 = f_1, f_0
wp.synchronize()

Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def test_bgk_ollision(dim, velocity_set, grid_shape, omega):

# Compute collision

compute_collision = BGK(omega=omega)
compute_collision = BGK()

f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)

f_out = compute_collision(f_orig, f_eq, rho, u)
f_out = compute_collision(f_orig, f_eq, rho, u, omega)

assert jnp.allclose(f_out, f_orig - omega * (f_orig - f_eq))

Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega):
f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
f_eq = compute_macro(rho, u, f_eq)

compute_collision = BGK(omega=omega)
compute_collision = BGK()
f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)

f_out = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
f_out = compute_collision(f_orig, f_eq, f_out, rho, u)
f_out = compute_collision(f_orig, f_eq, f_out, rho, u, omega)

f_eq = f_eq.numpy()
f_out = f_out.numpy()
Expand Down
2 changes: 1 addition & 1 deletion xlb/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def check_backend_support():
elif len(gpus) == 1:
print("Single-GPU support is available: 1 GPU detected.")

if jax.devices()[0].platform == "tpu":
elif jax.devices()[0].platform == "tpu":
tpus = jax.devices("tpu")
if len(tpus) > 1:
print("Multi-TPU support is available: {} TPUs detected.".format(len(tpus)))
Expand Down
15 changes: 8 additions & 7 deletions xlb/operator/collision/bgk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,21 @@ class BGK(Collision):

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u):
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega):
fneq = f - feq
fout = f - self.compute_dtype(self.omega) * fneq
fout = f - self.compute_dtype(omega) * fneq
return fout

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
_w = self.velocity_set.w
_omega = wp.constant(self.compute_dtype(self.omega))
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)

# Construct the functional
@wp.func
def functional(f: Any, feq: Any, rho: Any, u: Any):
def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any):
fneq = f - feq
fout = f - _omega * fneq
fout = f - self.compute_dtype(omega) * fneq
return fout

# Construct the warp kernel
Expand All @@ -42,6 +41,7 @@ def kernel(
fout: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
i, j, k = wp.tid()
Expand All @@ -55,7 +55,7 @@ def kernel(
_feq[l] = feq[l, index[0], index[1], index[2]]

# Compute the collision
_fout = functional(_f, _feq, rho, u)
_fout = functional(_f, _feq, rho, u, omega)

# Write the result
for l in range(self.velocity_set.q):
Expand All @@ -64,7 +64,7 @@ def kernel(
return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, feq, fout, rho, u):
def warp_implementation(self, f, feq, fout, rho, u, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -74,6 +74,7 @@ def warp_implementation(self, f, feq, fout, rho, u):
fout,
rho,
u,
omega,
],
dim=f.shape[1:],
)
Expand Down
9 changes: 0 additions & 9 deletions xlb/operator/collision/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,12 @@ class Collision(Operator):
Base class for collision operators.

This class defines the collision step for the Lattice Boltzmann Method.

Parameters
----------
omega : float
Relaxation parameter for collision step. Default value is 0.6.
shear : bool
Flag to indicate whether the collision step requires the shear stress.
"""

def __init__(
self,
omega: float,
velocity_set: VelocitySet = None,
precision_policy=None,
compute_backend=None,
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the extra line after ):

self.omega = omega
super().__init__(velocity_set, precision_policy, compute_backend)
16 changes: 9 additions & 7 deletions xlb/operator/collision/forced_collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
):
assert collision_operator is not None
self.collision_operator = collision_operator
super().__init__(self.collision_operator.omega)
super().__init__()

assert forcing_scheme == "exact_difference", NotImplementedError(f"Force model {forcing_scheme} not implemented!")
assert force_vector.shape[0] == self.velocity_set.d, "Check the dimensions of the input force!"
Expand All @@ -33,8 +33,8 @@ def __init__(

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u):
fout = self.collision_operator(f, feq, rho, u)
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega):
fout = self.collision_operator(f, feq, rho, u, omega)
fout = self.forcing_operator(fout, feq, rho, u)
return fout

Expand All @@ -45,8 +45,8 @@ def _construct_warp(self):

# Construct the functional
@wp.func
def functional(f: Any, feq: Any, rho: Any, u: Any):
fout = self.collision_operator.warp_functional(f, feq, rho, u)
def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any):
fout = self.collision_operator.warp_functional(f, feq, rho, u, omega)
fout = self.forcing_operator.warp_functional(fout, feq, rho, u)
return fout

Expand All @@ -58,6 +58,7 @@ def kernel(
fout: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
i, j, k = wp.tid()
Expand All @@ -76,7 +77,7 @@ def kernel(
_rho = rho[0, index[0], index[1], index[2]]

# Compute the collision
_fout = functional(_f, _feq, _rho, _u)
_fout = functional(_f, _feq, _rho, _u, omega)

# Write the result
for l in range(self.velocity_set.q):
Expand All @@ -85,7 +86,7 @@ def kernel(
return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, feq, fout, rho, u):
def warp_implementation(self, f, feq, fout, rho, u, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -95,6 +96,7 @@ def warp_implementation(self, f, feq, fout, rho, u):
fout,
rho,
u,
omega,
],
dim=f.shape[1:],
)
Expand Down
Loading
Loading