From d91f8e9fb96d1c1db9b203ca8dde942435aeebf3 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 16 Jan 2025 18:04:20 -0500 Subject: [PATCH] Converted examples to functional. Made compute_backend name consistent --- examples/cfd/flow_past_sphere_3d.py | 259 +++++------ examples/cfd/lid_driven_cavity_2d.py | 22 +- .../cfd/lid_driven_cavity_2d_distributed.py | 12 +- examples/cfd/turbulent_channel_3d.py | 342 +++++++------- examples/cfd/windtunnel_3d.py | 429 ++++++++++-------- examples/performance/mlups_3d.py | 76 ++-- .../bc_equilibrium/test_bc_equilibrium_jax.py | 2 +- .../test_bc_equilibrium_warp.py | 2 +- .../test_bc_fullway_bounce_back_jax.py | 2 +- .../test_bc_fullway_bounce_back_warp.py | 2 +- .../mask/test_bc_indices_masker_jax.py | 2 +- .../mask/test_bc_indices_masker_warp.py | 2 +- tests/grids/test_grid_jax.py | 2 +- tests/grids/test_grid_warp.py | 2 +- .../collision/test_bgk_collision_jax.py | 2 +- .../collision/test_bgk_collision_warp.py | 2 +- .../equilibrium/test_equilibrium_jax.py | 2 +- .../equilibrium/test_equilibrium_warp.py | 2 +- .../macroscopic/test_macroscopic_jax.py | 2 +- .../macroscopic/test_macroscopic_warp.py | 2 +- tests/kernels/stream/test_stream_jax.py | 2 +- tests/kernels/stream/test_stream_warp.py | 2 +- xlb/grid/jax_grid.py | 2 +- xlb/helper/check_boundary_overlaps.py | 6 +- xlb/helper/initializers.py | 6 +- xlb/operator/operator.py | 21 +- xlb/velocity_set/d2q9.py | 4 +- xlb/velocity_set/d3q19.py | 4 +- xlb/velocity_set/d3q27.py | 4 +- xlb/velocity_set/velocity_set.py | 17 +- 30 files changed, 647 insertions(+), 589 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 00a6aabb..b74128ae 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -16,95 +16,54 @@ import jax.numpy as jnp import time +# -------------------------- Simulation Setup -------------------------- + +omega = 1.6 +grid_shape = (512 // 2, 128 // 2, 128 // 2) +compute_backend = ComputeBackend.WARP +precision_policy = PrecisionPolicy.FP32FP32 +velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend) +u_max = 0.04 +num_steps = 10000 +post_process_interval = 1000 + +# Initialize XLB +xlb.init( + velocity_set=velocity_set, + default_backend=compute_backend, + default_precision_policy=precision_policy, +) -class FlowOverSphere: - def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): - # initialize backend - xlb.init( - velocity_set=velocity_set, - default_backend=backend, - default_precision_policy=precision_policy, - ) - - self.grid_shape = grid_shape - self.velocity_set = velocity_set - self.backend = backend - self.precision_policy = precision_policy - self.omega = omega - - self.boundary_conditions = [] - self.u_max = 0.04 - - # Create grid using factory - self.grid = grid_factory(grid_shape, compute_backend=backend) - - # Setup the simulation BC and stepper - self._setup() - - def _setup(self): - self.setup_boundary_conditions() - self.setup_stepper() - - def define_boundary_indices(self): - box = self.grid.bounding_box_indices() - box_no_edge = self.grid.bounding_box_indices(remove_edges=True) - inlet = box_no_edge["left"] - outlet = box_no_edge["right"] - walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)] - walls = np.unique(np.array(walls), axis=-1).tolist() - - sphere_radius = self.grid_shape[1] // 12 - x = np.arange(self.grid_shape[0]) - y = np.arange(self.grid_shape[1]) - z = np.arange(self.grid_shape[2]) - X, Y, Z = np.meshgrid(x, y, z, indexing="ij") - indices = np.where( - (X - self.grid_shape[0] // 6) ** 2 + (Y - self.grid_shape[1] // 2) ** 2 + (Z - self.grid_shape[2] // 2) ** 2 < sphere_radius**2 - ) - sphere = [tuple(indices[i]) for i in range(self.velocity_set.d)] - - return inlet, outlet, walls, sphere - - def setup_boundary_conditions(self): - inlet, outlet, walls, sphere = self.define_boundary_indices() - bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet) - # bc_left = RegularizedBC("velocity", prescribed_value=(self.u_max, 0.0, 0.0), indices=inlet) - bc_walls = FullwayBounceBackBC(indices=walls) - bc_outlet = ExtrapolationOutflowBC(indices=outlet) - bc_sphere = HalfwayBounceBackBC(indices=sphere) - self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere] - - def setup_stepper(self): - self.stepper = IncompressibleNavierStokesStepper( - grid=self.grid, - boundary_conditions=self.boundary_conditions, - collision_type="BGK", - ) - self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields() - - def bc_profile(self): - u_max = self.u_max # u_max = 0.04 - # Get the grid dimensions for the y and z directions - H_y = float(self.grid_shape[1] - 1) # Height in y direction - H_z = float(self.grid_shape[2] - 1) # Height in z direction +# Create Grid +grid = grid_factory(grid_shape, compute_backend=compute_backend) - @wp.func - def bc_profile_warp(index: wp.vec3i): - # Poiseuille flow profile: parabolic velocity distribution - y = wp.float32(index[1]) - z = wp.float32(index[2]) +# Define Boundary Indices +box = grid.bounding_box_indices() +box_no_edge = grid.bounding_box_indices(remove_edges=True) +inlet = box_no_edge["left"] +outlet = box_no_edge["right"] +walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(velocity_set.d)] +walls = np.unique(np.array(walls), axis=-1).tolist() - # Calculate normalized distance from center - y_center = y - (H_y / 2.0) - z_center = z - (H_z / 2.0) - r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0 +sphere_radius = grid_shape[1] // 12 +x = np.arange(grid_shape[0]) +y = np.arange(grid_shape[1]) +z = np.arange(grid_shape[2]) +X, Y, Z = np.meshgrid(x, y, z, indexing="ij") +indices = np.where((X - grid_shape[0] // 6) ** 2 + (Y - grid_shape[1] // 2) ** 2 + (Z - grid_shape[2] // 2) ** 2 < sphere_radius**2) +sphere = [tuple(indices[i]) for i in range(velocity_set.d)] - # Parabolic profile: u = u_max * (1 - r²) - return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1) + +# Define Boundary Conditions +def bc_profile(): + H_y = float(grid_shape[1] - 1) # Height in y direction + H_z = float(grid_shape[2] - 1) # Height in z direction + + if compute_backend == ComputeBackend.JAX: def bc_profile_jax(): - y = jnp.arange(self.grid_shape[1]) - z = jnp.arange(self.grid_shape[2]) + y = jnp.arange(grid_shape[1]) + z = jnp.arange(grid_shape[2]) Y, Z = jnp.meshgrid(y, z, indexing="ij") # Calculate normalized distance from center @@ -119,56 +78,88 @@ def bc_profile_jax(): return jnp.stack([u_x, u_y, u_z]) - if self.backend == ComputeBackend.JAX: - return bc_profile_jax - elif self.backend == ComputeBackend.WARP: - return bc_profile_warp + return bc_profile_jax + + elif compute_backend == ComputeBackend.WARP: + + @wp.func + def bc_profile_warp(index: wp.vec3i): + # Poiseuille flow profile: parabolic velocity distribution + y = wp.float32(index[1]) + z = wp.float32(index[2]) + + # Calculate normalized distance from center + y_center = y - (H_y / 2.0) + z_center = z - (H_z / 2.0) + r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0 + + # Parabolic profile: u = u_max * (1 - r²) + return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1) + + return bc_profile_warp + + +# Initialize Boundary Conditions +bc_left = RegularizedBC("velocity", profile=bc_profile(), indices=inlet) +# Alternatively, use a prescribed velocity profile +# bc_left = RegularizedBC("velocity", prescribed_value=(u_max, 0.0, 0.0), indices=inlet) +bc_walls = FullwayBounceBackBC(indices=walls) +bc_outlet = ExtrapolationOutflowBC(indices=outlet) +bc_sphere = HalfwayBounceBackBC(indices=sphere) +boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere] + +# Setup Stepper +stepper = IncompressibleNavierStokesStepper( + grid=grid, + boundary_conditions=boundary_conditions, + collision_type="BGK", +) +f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields() + +# Define Macroscopic Calculation +macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=precision_policy, + velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX), +) + + +# Post-Processing Function +def post_process(step, f_current): + # Convert to JAX array if necessary + if not isinstance(f_current, jnp.ndarray): + f_current = wp.to_jax(f_current) + + rho, u = macro(f_current) + + # Remove boundary cells + u = u[:, 1:-1, 1:-1, 1:-1] + rho = rho[:, 1:-1, 1:-1, 1:-1] + u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2 + u[2] ** 2) + + fields = { + "u_magnitude": u_magnitude, + "u_x": u[0], + "u_y": u[1], + "u_z": u[2], + "rho": rho[0], + } + + # Save the u_magnitude slice at the mid y-plane + save_image(fields["u_magnitude"][:, grid_shape[1] // 2, :], timestep=step) + print(f"Post-processed step {step}: Saved u_magnitude slice at y={grid_shape[1] // 2}") + + +# -------------------------- Simulation Loop -------------------------- + +start_time = time.time() +for step in range(num_steps): + f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step) + f_0, f_1 = f_1, f_0 # Swap the buffers - def run(self, num_steps, post_process_interval=100): + if step % post_process_interval == 0 or step == num_steps - 1: + post_process(step, f_0) + end_time = time.time() + elapsed = end_time - start_time + print(f"Completed step {step}. Time elapsed for {post_process_interval} steps: {elapsed:.6f} seconds.") start_time = time.time() - for i in range(num_steps): - self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i) - self.f_0, self.f_1 = self.f_1, self.f_0 - - if i % post_process_interval == 0 or i == num_steps - 1: - self.post_process(i) - end_time = time.time() - print(f"Completing {i} iterations. Time elapsed for 1000 LBM steps in {end_time - start_time:.6f} seconds.") - start_time = time.time() - - def post_process(self, i): - # Write the results. We'll use JAX backend for the post-processing - if not isinstance(self.f_0, jnp.ndarray): - f_0 = wp.to_jax(self.f_0) - else: - f_0 = self.f_0 - - macro = Macroscopic( - compute_backend=ComputeBackend.JAX, - precision_policy=self.precision_policy, - velocity_set=xlb.velocity_set.D3Q19(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), - ) - rho, u = macro(f_0) - - # remove boundary cells - u = u[:, 1:-1, 1:-1, 1:-1] - rho = rho[:, 1:-1, 1:-1, 1:-1] - u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - - fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho[0]} - - # save_fields_vtk(fields, timestep=i) - save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) - - -if __name__ == "__main__": - # Running the simulation - grid_shape = (512 // 2, 128 // 2, 128 // 2) - backend = ComputeBackend.WARP - precision_policy = PrecisionPolicy.FP32FP32 - - velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) - omega = 1.6 - - simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy) - simulation.run(num_steps=10000, post_process_interval=1000) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index b681b996..0ca20033 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -13,24 +13,24 @@ class LidDrivenCavity2D: - def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy): - # initialize backend + def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy): + # initialize compute_backend xlb.init( velocity_set=velocity_set, - default_backend=backend, + default_backend=compute_backend, default_precision_policy=precision_policy, ) self.grid_shape = grid_shape self.velocity_set = velocity_set - self.backend = backend + self.compute_backend = compute_backend self.precision_policy = precision_policy self.omega = omega self.boundary_conditions = [] self.prescribed_vel = prescribed_vel # Create grid using factory - self.grid = grid_factory(grid_shape, compute_backend=backend) + self.grid = grid_factory(grid_shape, compute_backend=compute_backend) # Setup the simulation BC and stepper self._setup() @@ -71,9 +71,9 @@ def run(self, num_steps, post_process_interval=100): self.post_process(i) def post_process(self, i): - # Write the results. We'll use JAX backend for the post-processing + # Write the results. We'll use JAX compute_backend for the post-processing if not isinstance(self.f_0, jnp.ndarray): - # If the backend is warp, we need to drop the last dimension added by warp for 2D simulations + # If the compute_backend is warp, we need to drop the last dimension added by warp for 2D simulations f_0 = wp.to_jax(self.f_0)[..., 0] else: f_0 = self.f_0 @@ -81,7 +81,7 @@ def post_process(self, i): macro = Macroscopic( compute_backend=ComputeBackend.JAX, precision_policy=self.precision_policy, - velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, compute_backend=ComputeBackend.JAX), ) rho, u = macro(f_0) @@ -101,10 +101,10 @@ def post_process(self, i): # Running the simulation grid_size = 500 grid_shape = (grid_size, grid_size) - backend = ComputeBackend.WARP + compute_backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 - velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, compute_backend=compute_backend) # Setting fluid viscosity and relaxation parameter. Re = 200.0 @@ -113,5 +113,5 @@ def post_process(self, i): visc = prescribed_vel * clength / Re omega = 1.0 / (3.0 * visc + 0.5) - simulation = LidDrivenCavity2D(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy) + simulation = LidDrivenCavity2D(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy) simulation.run(num_steps=50000, post_process_interval=1000) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index d06d314a..06a2822a 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -7,8 +7,8 @@ class LidDrivenCavity2D_distributed(LidDrivenCavity2D): - def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy): - super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy) + def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy): + super().__init__(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy) def setup_stepper(self): # Create the base stepper @@ -30,10 +30,12 @@ def setup_stepper(self): # Running the simulation grid_size = 512 grid_shape = (grid_size, grid_size) - backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! + compute_backend = ( + ComputeBackend.JAX + ) # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! precision_policy = PrecisionPolicy.FP32FP32 - velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, compute_backend=compute_backend) # Setting fluid viscosity and relaxation parameter. Re = 200.0 @@ -42,5 +44,5 @@ def setup_stepper(self): visc = prescribed_vel * clength / Re omega = 1.0 / (3.0 * visc + 0.5) - simulation = LidDrivenCavity2D_distributed(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy) + simulation = LidDrivenCavity2D_distributed(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy) simulation.run(num_steps=50000, post_process_interval=1000) diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index 48dddc29..d9773ecf 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -15,7 +15,9 @@ import json -# helper functions for this benchmark example +# -------------------------- Helper Functions -------------------------- + + def vonKarman_loglaw_wall(yplus): vonKarmanConst = 0.41 cplus = 5.5 @@ -34,167 +36,179 @@ def get_dns_data(): return json.load(file) -class TurbulentChannel3D: - def __init__(self, channel_half_width, Re_tau, u_tau, grid_shape, velocity_set, backend, precision_policy): - # initialize backend - xlb.init( - velocity_set=velocity_set, - default_backend=backend, - default_precision_policy=precision_policy, - ) - - self.channel_half_width = channel_half_width - self.Re_tau = Re_tau - self.u_tau = u_tau - self.visc = u_tau * channel_half_width / Re_tau - self.omega = 1.0 / (3.0 * self.visc + 0.5) - self.grid_shape = grid_shape - self.velocity_set = velocity_set - self.backend = backend - self.precision_policy = precision_policy - self.boundary_conditions = [] - - # Create grid using factory - self.grid = grid_factory(grid_shape, compute_backend=backend) - - # Setup the simulation BC and stepper - self._setup() - - def get_force(self): - # define the external force - shape = (self.velocity_set.d,) - force = np.zeros(shape) - force[0] = self.Re_tau**2 * self.visc**2 / self.channel_half_width**3 - return force - - def _setup(self): - self.setup_boundary_conditions() - self.setup_stepper() - # Initialize fields using the stepper - self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields() - self.initialize_fields() - - def define_boundary_indices(self): - # top and bottom sides of the channel are no-slip and the other directions are periodic - box = self.grid.bounding_box_indices(remove_edges=True) - walls = [box["bottom"][i] + box["top"][i] for i in range(self.velocity_set.d)] - return walls - - def setup_boundary_conditions(self): - walls = self.define_boundary_indices() - bc_walls = RegularizedBC("velocity", prescribed_value=(0.0, 0.0, 0.0), indices=walls) - self.boundary_conditions = [bc_walls] - - def initialize_fields(self): - # Initialize with random velocity field - shape = (self.velocity_set.d,) + self.grid_shape - np.random.seed(0) - u_init = np.random.random(shape) - if self.backend == ComputeBackend.JAX: - u_init = jnp.full(shape=shape, fill_value=1e-2 * u_init) - else: - u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype) - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init) - - def setup_stepper(self): - self.stepper = IncompressibleNavierStokesStepper( - grid=self.grid, - boundary_conditions=self.boundary_conditions, - collision_type="KBC", - force_vector=self.get_force(), - ) - - def run(self, num_steps, print_interval, post_process_interval=100): +# -------------------------- Simulation Setup -------------------------- + +# Channel Parameter +channel_half_width = 50 + +# Define channel geometry based on h +grid_size_x = 6 * channel_half_width +grid_size_y = 3 * channel_half_width +grid_size_z = 2 * channel_half_width + +# Grid parameters +grid_shape = (grid_size_x, grid_size_y, grid_size_z) + +# Define flow regime +Re_tau = 180 +u_tau = 0.001 + +# Compute viscosity and relaxation parameter omega +visc = u_tau * channel_half_width / Re_tau +omega = 1.0 / (3.0 * visc + 0.5) + +# Runtime & compute_backend configurations +compute_backend = ComputeBackend.WARP +precision_policy = PrecisionPolicy.FP64FP64 +velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend) +num_steps = 10000000 +print_interval = 100000 +post_process_interval = 100000 + +# Print simulation info +print("\n" + "=" * 50 + "\n") +print("Simulation Configuration:") +print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") +print(f"Backend: {compute_backend}") +print(f"Velocity set: {velocity_set}") +print(f"Precision policy: {precision_policy}") +print(f"Reynolds number: {Re_tau}") +print(f"Max iterations: {num_steps}") +print("\n" + "=" * 50 + "\n") + +# Initialize XLB +xlb.init( + velocity_set=velocity_set, + default_backend=compute_backend, + default_precision_policy=precision_policy, +) + +# Create Grid +grid = grid_factory(grid_shape, compute_backend=compute_backend) + + +# Define Force Vector +def get_force(Re_tau, visc, channel_half_width, velocity_set): + shape = (velocity_set.d,) + force = np.zeros(shape) + force[0] = Re_tau**2 * visc**2 / channel_half_width**3 + return force + + +force_vector = get_force(Re_tau, visc, channel_half_width, velocity_set) + + +# Define Boundary Indices +box = grid.bounding_box_indices(remove_edges=True) +walls = [box["bottom"][i] + box["top"][i] for i in range(velocity_set.d)] + +# Define Boundary Conditions +def setup_boundary_conditions(walls, velocity_set, precision_policy): + # No-slip boundary condition: velocity = (0, 0, 0) + bc_walls = RegularizedBC("velocity", prescribed_value=(0.0, 0.0, 0.0), indices=walls) + boundary_conditions = [bc_walls] + return boundary_conditions + + +boundary_conditions = setup_boundary_conditions(walls, velocity_set, precision_policy) + +# Setup Stepper +stepper = IncompressibleNavierStokesStepper( + grid=grid, + boundary_conditions=boundary_conditions, + collision_type="KBC", + force_vector=force_vector, +) + +# Prepare Fields +f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields() + + +# Initialize Fields with Random Velocity +shape = (velocity_set.d,) + grid.shape +np.random.seed(0) +u_init = np.random.random(shape) +if compute_backend == ComputeBackend.JAX: + u_init = jnp.full(shape=shape, fill_value=1e-2 * u_init) +else: + u_init = wp.array(1e-2 * u_init, dtype=precision_policy.compute_precision.wp_dtype) + +f_0 = initialize_eq(f_0, grid, velocity_set, precision_policy, compute_backend, u=u_init) + +# Define Macroscopic Calculation +macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=precision_policy, + velocity_set=xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX), +) + + +# Post-Processing Function +def post_process(step, f_current, grid_shape, macro): + # Convert to JAX array if necessary + if not isinstance(f_current, jnp.ndarray): + f_current = wp.to_jax(f_current) + + rho, u = macro(f_current) + + # Compute velocity magnitude + u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2 + u[2] ** 2) + fields = { + "rho": rho[0], + "u_x": u[0], + "u_y": u[1], + "u_z": u[2], + "u_magnitude": u_magnitude, + } + + # Save the fields in VTK format + save_fields_vtk(fields, timestep=step) + + # Save the u_magnitude slice at the mid y-plane + mid_y = grid_shape[1] // 2 + save_image(fields["u_magnitude"][:, mid_y, :], timestep=step) + + # Save monitor plot + plot_uplus(u, step, grid_shape, u_tau, visc) + + +# Plotting Function +def plot_uplus(u, timestep, grid_shape, u_tau, visc): + # Mean streamwise velocity in wall units u^+(z) + zz = np.arange(grid_shape[-1]) + zz = np.minimum(zz, zz.max() - zz) + yplus = zz * u_tau / visc + uplus = np.mean(u[0], axis=(0, 1)) / u_tau + uplus_loglaw = vonKarman_loglaw_wall(yplus) + dns_dic = get_dns_data() + + plt.clf() + plt.semilogx(yplus, uplus, "r.", label="Simulation") + plt.semilogx(yplus, uplus_loglaw, "k:", label="Von Karman Log Law") + plt.semilogx(dns_dic["y+"], dns_dic["Umean"], "b-", label="DNS Data") + ax = plt.gca() + ax.set_xlim([0.1, 300]) + ax.set_ylim([0, 20]) + plt.xlabel("y+") + plt.ylabel("U+") + plt.title(f"u+ vs y+ at timestep {timestep}") + plt.legend() + fname = f"uplus_{str(timestep).zfill(5)}.png" + plt.savefig(fname, format="png") + plt.close() + + +# -------------------------- Simulation Loop -------------------------- + +start_time = time.time() +for step in range(num_steps): + f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step) + f_0, f_1 = f_1, f_0 # Swap the buffers + + if (step + 1) % print_interval == 0: + elapsed_time = time.time() - start_time + print(f"Iteration: {step + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") start_time = time.time() - for i in range(num_steps): - self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i) - self.f_0, self.f_1 = self.f_1, self.f_0 - - if (i + 1) % print_interval == 0: - elapsed_time = time.time() - start_time - print(f"Iteration: {i + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") - - if i % post_process_interval == 0 or i == num_steps - 1: - self.post_process(i) - - def post_process(self, i): - # Write the results. We'll use JAX backend for the post-processing - if not isinstance(self.f_0, jnp.ndarray): - f_0 = wp.to_jax(self.f_0) - else: - f_0 = self.f_0 - - macro = Macroscopic( - compute_backend=ComputeBackend.JAX, - precision_policy=self.precision_policy, - velocity_set=xlb.velocity_set.D3Q27(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), - ) - - rho, u = macro(f_0) - - # compute velocity magnitude - u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_z": u[2], "u_magnitude": u_magnitude} - save_fields_vtk(fields, timestep=i) - save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) - - # Save monitor plot - self.plot_uplus(u, i) - - def plot_uplus(self, u, timestep): - # mean streamwise velocity in wall units u^+(z) - # Wall distance in wall units to be used inside output_data - zz = np.arange(self.grid_shape[-1]) - zz = np.minimum(zz, zz.max() - zz) - yplus = zz * self.u_tau / self.visc - uplus = np.mean(u[0], axis=(0, 1)) / self.u_tau - uplus_loglaw = vonKarman_loglaw_wall(yplus) - dns_dic = get_dns_data() - plt.clf() - plt.semilogx(yplus, uplus, "r.", yplus, uplus_loglaw, "k:", dns_dic["y+"], dns_dic["Umean"], "b-") - ax = plt.gca() - ax.set_xlim([0.1, 300]) - ax.set_ylim([0, 20]) - fname = "uplus_" + str(timestep // 10000).zfill(5) + ".png" - plt.savefig(fname, format="png") - plt.close() - - -if __name__ == "__main__": - # Problem Configuration - # h: channel half-width - channel_half_width = 50 - - # Define channel geometry based on h - grid_size_x = 6 * channel_half_width - grid_size_y = 3 * channel_half_width - grid_size_z = 2 * channel_half_width - - # Grid parameters - grid_shape = (grid_size_x, grid_size_y, grid_size_z) - - # Define flow regime - # Set up Reynolds number and deduce relaxation time (omega) - Re_tau = 180 - u_tau = 0.001 - - # Runtime & backend configurations - backend = ComputeBackend.WARP - precision_policy = PrecisionPolicy.FP64FP64 - velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend) - num_steps = 10000000 - print_interval = 100000 - - # Print simulation info - print("\n" + "=" * 50 + "\n") - print("Simulation Configuration:") - print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") - print(f"Backend: {backend}") - print(f"Velocity set: {velocity_set}") - print(f"Precision policy: {precision_policy}") - print(f"Reynolds number: {Re_tau}") - print(f"Max iterations: {num_steps}") - print("\n" + "=" * 50 + "\n") - - simulation = TurbulentChannel3D(channel_half_width, Re_tau, u_tau, grid_shape, velocity_set, backend, precision_policy) - simulation.run(num_steps, print_interval, post_process_interval=100000) + + if (step % post_process_interval == 0) or (step == num_steps - 1): + post_process(step, f_0, grid_shape, macro) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 9f088c30..1f0d266e 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -20,204 +20,243 @@ import matplotlib.pyplot as plt -class WindTunnel3D: - def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precision_policy): - # initialize backend - xlb.init( - velocity_set=velocity_set, - default_backend=backend, - default_precision_policy=precision_policy, - ) +# -------------------------- Simulation Setup -------------------------- + +# Grid parameters +grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 +grid_shape = (grid_size_x, grid_size_y, grid_size_z) + +# Simulation Configuration +compute_backend = ComputeBackend.WARP +precision_policy = PrecisionPolicy.FP32FP32 + +velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend) +wind_speed = 0.02 +num_steps = 100000 +print_interval = 1000 +post_process_interval = 1000 + +# Physical Parameters +Re = 50000.0 +clength = grid_size_x - 1 +visc = wind_speed * clength / Re +omega = 1.0 / (3.0 * visc + 0.5) + +# Print simulation info +print("\n" + "=" * 50 + "\n") +print("Simulation Configuration:") +print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") +print(f"Backend: {compute_backend}") +print(f"Velocity set: {velocity_set}") +print(f"Precision policy: {precision_policy}") +print(f"Prescribed velocity: {wind_speed}") +print(f"Reynolds number: {Re}") +print(f"Max iterations: {num_steps}") +print("\n" + "=" * 50 + "\n") + +# Initialize XLB +xlb.init( + velocity_set=velocity_set, + default_backend=compute_backend, + default_precision_policy=precision_policy, +) - self.grid_shape = grid_shape - self.velocity_set = velocity_set - self.backend = backend - self.precision_policy = precision_policy - self.omega = omega - self.boundary_conditions = [] - self.wind_speed = wind_speed - - # Create grid using factory - self.grid = grid_factory(grid_shape, compute_backend=backend) - - # Setup the simulation BC and stepper - self._setup() - - # Make list to store drag coefficients - self.time_steps = [] - self.drag_coefficients = [] - self.lift_coefficients = [] - - def _setup(self): - self.setup_boundary_conditions() - self.setup_stepper() - # Initialize fields using the stepper - self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields() - - def voxelize_stl(self, stl_filename, length_lbm_unit): - mesh = trimesh.load_mesh(stl_filename, process=False) - length_phys_unit = mesh.extents.max() - pitch = length_phys_unit / length_lbm_unit - mesh_voxelized = mesh.voxelized(pitch=pitch) - mesh_matrix = mesh_voxelized.matrix - return mesh_matrix, pitch - - def define_boundary_indices(self): - box = self.grid.bounding_box_indices() - box_no_edge = self.grid.bounding_box_indices(remove_edges=True) - inlet = box_no_edge["left"] - outlet = box_no_edge["right"] - walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)] - walls = np.unique(np.array(walls), axis=-1).tolist() - - # Load the mesh (replace with your own mesh) - stl_filename = "../stl-files/DrivAer-Notchback.stl" - mesh = trimesh.load_mesh(stl_filename, process=False) - mesh_vertices = mesh.vertices - - # Transform the mesh points to be located in the right position in the wind tunnel - mesh_vertices -= mesh_vertices.min(axis=0) - mesh_extents = mesh_vertices.max(axis=0) - length_phys_unit = mesh_extents.max() - length_lbm_unit = self.grid_shape[0] / 4 - dx = length_phys_unit / length_lbm_unit - mesh_vertices = mesh_vertices / dx - shift = np.array([self.grid_shape[0] / 4, (self.grid_shape[1] - mesh_extents[1] / dx) / 2, 0.0]) - car = mesh_vertices + shift - self.car_cross_section = np.prod(mesh_extents[1:]) / dx**2 - - return inlet, outlet, walls, car - - def setup_boundary_conditions(self): - inlet, outlet, walls, car = self.define_boundary_indices() - bc_left = RegularizedBC("velocity", prescribed_value=(self.wind_speed, 0.0, 0.0), indices=inlet) - bc_walls = FullwayBounceBackBC(indices=walls) - bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) - bc_car = HalfwayBounceBackBC(mesh_vertices=car) - self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car] - - def setup_stepper(self): - self.stepper = IncompressibleNavierStokesStepper( - grid=self.grid, - boundary_conditions=self.boundary_conditions, - collision_type="KBC", - ) +# Create Grid +grid = grid_factory(grid_shape, compute_backend=compute_backend) + +# Bounding box indices +box = grid.bounding_box_indices() +box_no_edge = grid.bounding_box_indices(remove_edges=True) +inlet = box_no_edge["left"] +outlet = box_no_edge["right"] +walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(velocity_set.d)] +walls = np.unique(np.array(walls), axis=-1).tolist() + +# Load the mesh (replace with your own mesh) +stl_filename = "../stl-files/DrivAer-Notchback.stl" +mesh = trimesh.load_mesh(stl_filename, process=False) +mesh_vertices = mesh.vertices + +# Transform the mesh points to align with the grid +mesh_vertices -= mesh_vertices.min(axis=0) +mesh_extents = mesh_vertices.max(axis=0) +length_phys_unit = mesh_extents.max() +length_lbm_unit = grid_shape[0] / 4 +dx = length_phys_unit / length_lbm_unit +mesh_vertices = mesh_vertices / dx +shift = np.array([grid_shape[0] / 4, (grid_shape[1] - mesh_extents[1] / dx) / 2, 0.0]) +car_vertices = mesh_vertices + shift +car_cross_section = np.prod(mesh_extents[1:]) / dx**2 + + +bc_left = RegularizedBC("velocity", prescribed_value=(wind_speed, 0.0, 0.0), indices=inlet) +bc_walls = FullwayBounceBackBC(indices=walls) +bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) +bc_car = HalfwayBounceBackBC(mesh_vertices=car_vertices) +boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car] + + +# Setup Stepper +stepper = IncompressibleNavierStokesStepper( + grid=grid, + boundary_conditions=boundary_conditions, + collision_type="KBC", +) - def run(self, num_steps, print_interval, post_process_interval=100): - # Setup the operator for computing surface forces at the interface of the specified BC - bc_car = self.boundary_conditions[-1] - self.momentum_transfer = MomentumTransfer(bc_car) +# Prepare Fields +f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields() + + +# -------------------------- Helper Functions -------------------------- + + +def plot_drag_coefficient(time_steps, drag_coefficients): + """ + Plot the drag coefficient with various moving averages. + + Args: + time_steps (list): List of time steps. + drag_coefficients (list): List of drag coefficients. + """ + # Convert lists to numpy arrays for processing + time_steps_np = np.array(time_steps) + drag_coefficients_np = np.array(drag_coefficients) + + # Define moving average windows + windows = [10, 100, 1000, 10000, 100000] + labels = ["MA 10", "MA 100", "MA 1,000", "MA 10,000", "MA 100,000"] + + plt.figure(figsize=(12, 8)) + plt.plot(time_steps_np, drag_coefficients_np, label="Raw", alpha=0.5) + + for window, label in zip(windows, labels): + if len(drag_coefficients_np) >= window: + ma = np.convolve(drag_coefficients_np, np.ones(window) / window, mode="valid") + plt.plot(time_steps_np[window - 1 :], ma, label=label) + + plt.ylim(-1.0, 1.0) + plt.legend() + plt.xlabel("Time step") + plt.ylabel("Drag coefficient") + plt.title("Drag Coefficient Over Time with Moving Averages") + plt.savefig("drag_coefficient_ma.png") + plt.close() + + +def post_process( + step, + f_0, + f_1, + grid_shape, + macro, + momentum_transfer, + missing_mask, + bc_mask, + wind_speed, + car_cross_section, + drag_coefficients, + lift_coefficients, + time_steps, +): + """ + Post-process simulation data: save fields, compute forces, and plot drag coefficient. + + Args: + step (int): Current time step. + f_current: Current distribution function. + grid_shape (tuple): Shape of the grid. + macro: Macroscopic operator object. + momentum_transfer: MomentumTransfer operator object. + missing_mask: Missing mask from stepper. + bc_mask: Boundary condition mask from stepper. + wind_speed (float): Prescribed wind speed. + car_cross_section (float): Cross-sectional area of the car. + drag_coefficients (list): List to store drag coefficients. + lift_coefficients (list): List to store lift coefficients. + time_steps (list): List to store time steps. + """ + # Convert to JAX array if necessary + if not isinstance(f_0, jnp.ndarray): + f_0_jax = wp.to_jax(f_0) + else: + f_0_jax = f_0 + + # Compute macroscopic quantities + rho, u = macro(f_0_jax) + + # Remove boundary cells + u = u[:, 1:-1, 1:-1, 1:-1] + u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2 + u[2] ** 2) + + fields = {"u_magnitude": u_magnitude} + + # Save fields in VTK format + save_fields_vtk(fields, timestep=step) + + # Save the u_magnitude slice at the mid y-plane + mid_y = grid_shape[1] // 2 + save_image(fields["u_magnitude"][:, mid_y, :], timestep=step) + + # Compute lift and drag + boundary_force = momentum_transfer(f_0, f_1, bc_mask, missing_mask) + drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane + lift = boundary_force[2] + c_d = 2.0 * drag / (wind_speed**2 * car_cross_section) + c_l = 2.0 * lift / (wind_speed**2 * car_cross_section) + drag_coefficients.append(c_d) + lift_coefficients.append(c_l) + time_steps.append(step) + + # Plot drag coefficient + plot_drag_coefficient(time_steps, drag_coefficients) + + +# Setup Momentum Transfer for Force Calculation +bc_car = boundary_conditions[-1] +momentum_transfer = MomentumTransfer(bc_car, compute_backend=compute_backend) + +# Define Macroscopic Calculation +macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=precision_policy, + velocity_set=xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX), +) +# Initialize Lists to Store Coefficients and Time Steps +time_steps = [] +drag_coefficients = [] +lift_coefficients = [] + +# -------------------------- Simulation Loop -------------------------- + +start_time = time.time() +for step in range(num_steps): + # Perform simulation step + f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step) + f_0, f_1 = f_1, f_0 # Swap the buffers + + # Print progress at intervals + if (step + 1) % print_interval == 0: + elapsed_time = time.time() - start_time + print(f"Iteration: {step + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") start_time = time.time() - for i in range(num_steps): - self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i) - self.f_0, self.f_1 = self.f_1, self.f_0 - - if (i + 1) % print_interval == 0: - elapsed_time = time.time() - start_time - print(f"Iteration: {i + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") - - if i % post_process_interval == 0 or i == num_steps - 1: - self.post_process(i) - - def post_process(self, i): - # Write the results. We'll use JAX backend for the post-processing - if not isinstance(self.f_0, jnp.ndarray): - f_0 = wp.to_jax(self.f_0) - else: - f_0 = self.f_0 - - macro = Macroscopic( - compute_backend=ComputeBackend.JAX, - precision_policy=self.precision_policy, - velocity_set=xlb.velocity_set.D3Q27(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + + # Post-process at intervals and final step + if (step % post_process_interval == 0) or (step == num_steps - 1): + post_process( + step, + f_0, + f_1, + grid_shape, + macro, + momentum_transfer, + missing_mask, + bc_mask, + wind_speed, + car_cross_section, + drag_coefficients, + lift_coefficients, + time_steps, ) - rho, u = macro(f_0) - - # remove boundary cells - u = u[:, 1:-1, 1:-1, 1:-1] - u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5 - - fields = {"u_magnitude": u_magnitude} - - save_fields_vtk(fields, timestep=i) - save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i) - - # Compute lift and drag - boundary_force = self.momentum_transfer(self.f_0, self.f_1, self.bc_mask, self.missing_mask) - drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane - lift = boundary_force[2] - c_d = 2.0 * drag / (self.wind_speed**2 * self.car_cross_section) - c_l = 2.0 * lift / (self.wind_speed**2 * self.car_cross_section) - self.drag_coefficients.append(c_d) - self.lift_coefficients.append(c_l) - self.time_steps.append(i) - - # Save monitor plot - self.plot_drag_coefficient() - - def plot_drag_coefficient(self): - # Compute moving average of drag coefficient, 100, 1000, 10000 - drag_coefficients = np.array(self.drag_coefficients) - self.drag_coefficients_ma_10 = np.convolve(drag_coefficients, np.ones(10) / 10, mode="valid") - self.drag_coefficients_ma_100 = np.convolve(drag_coefficients, np.ones(100) / 100, mode="valid") - self.drag_coefficients_ma_1000 = np.convolve(drag_coefficients, np.ones(1000) / 1000, mode="valid") - self.drag_coefficients_ma_10000 = np.convolve(drag_coefficients, np.ones(10000) / 10000, mode="valid") - self.drag_coefficients_ma_100000 = np.convolve(drag_coefficients, np.ones(100000) / 100000, mode="valid") - - # Plot drag coefficient - plt.plot(self.time_steps, drag_coefficients, label="Raw") - if len(self.time_steps) > 10: - plt.plot(self.time_steps[9:], self.drag_coefficients_ma_10, label="MA 10") - if len(self.time_steps) > 100: - plt.plot(self.time_steps[99:], self.drag_coefficients_ma_100, label="MA 100") - if len(self.time_steps) > 1000: - plt.plot(self.time_steps[999:], self.drag_coefficients_ma_1000, label="MA 1,000") - if len(self.time_steps) > 10000: - plt.plot(self.time_steps[9999:], self.drag_coefficients_ma_10000, label="MA 10,000") - if len(self.time_steps) > 100000: - plt.plot(self.time_steps[99999:], self.drag_coefficients_ma_100000, label="MA 100,000") - - plt.ylim(-1.0, 1.0) - plt.legend() - plt.xlabel("Time step") - plt.ylabel("Drag coefficient") - plt.savefig("drag_coefficient_ma.png") - plt.close() - - -if __name__ == "__main__": - # Grid parameters - grid_size_x, grid_size_y, grid_size_z = 512, 128, 128 - grid_shape = (grid_size_x, grid_size_y, grid_size_z) - - # Configuration - backend = ComputeBackend.WARP - precision_policy = PrecisionPolicy.FP32FP32 - - velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend) - wind_speed = 0.02 - num_steps = 100000 - print_interval = 1000 - - # Set up Reynolds number and deduce relaxation time (omega) - Re = 50000.0 - clength = grid_size_x - 1 - visc = wind_speed * clength / Re - omega = 1.0 / (3.0 * visc + 0.5) - - # Print simulation info - print("\n" + "=" * 50 + "\n") - print("Simulation Configuration:") - print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") - print(f"Backend: {backend}") - print(f"Velocity set: {velocity_set}") - print(f"Precision policy: {precision_policy}") - print(f"Prescribed velocity: {wind_speed}") - print(f"Reynolds number: {Re}") - print(f"Max iterations: {num_steps}") - print("\n" + "=" * 50 + "\n") - - simulation = WindTunnel3D(omega, wind_speed, grid_shape, velocity_set, backend, precision_policy) - simulation.run(num_steps, print_interval, post_process_interval=1000) +print("Simulation completed successfully.") diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index d014b91c..3d18fd99 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -10,18 +10,18 @@ from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC from xlb.distribute import distribute +# -------------------------- Simulation Setup -------------------------- def parse_arguments(): parser = argparse.ArgumentParser(description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)") parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") - parser.add_argument("num_steps", type=int, help="Timestep for the simulation") - parser.add_argument("backend", type=str, help="Backend for the simulation (jax or warp)") + parser.add_argument("num_steps", type=int, help="Number of timesteps for the simulation") + parser.add_argument("compute_backend", type=str, help="Backend for the simulation (jax or warp)") parser.add_argument("precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)") return parser.parse_args() - def setup_simulation(args): - backend = ComputeBackend.JAX if args.backend == "jax" else ComputeBackend.WARP + compute_backend = ComputeBackend.JAX if args.compute_backend == "jax" else ComputeBackend.WARP precision_policy_map = { "fp32/fp32": PrecisionPolicy.FP32FP32, "fp64/fp64": PrecisionPolicy.FP64FP64, @@ -30,68 +30,78 @@ def setup_simulation(args): } precision_policy = precision_policy_map.get(args.precision) if precision_policy is None: - raise ValueError("Invalid precision") + raise ValueError("Invalid precision specified.") xlb.init( - velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend), - default_backend=backend, + velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend), + default_backend=compute_backend, default_precision_policy=precision_policy, ) + return compute_backend, precision_policy - return backend, precision_policy - - -def run(backend, precision_policy, grid_shape, num_steps): - # Create grid and setup boundary conditions +def run_simulation(compute_backend, precision_policy, grid_shape, num_steps): grid = grid_factory(grid_shape) box = grid.bounding_box_indices() box_no_edge = grid.bounding_box_indices(remove_edges=True) + lid = box_no_edge["top"] - walls = [box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] for i in range(len(grid.shape))] + walls = [ + box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] + for i in range(len(grid.shape)) + ] walls = np.unique(np.array(walls), axis=-1).tolist() - boundary_conditions = [EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid), FullwayBounceBackBC(indices=walls)] + boundary_conditions = [ + EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid), + FullwayBounceBackBC(indices=walls), + ] - # Create stepper - stepper = IncompressibleNavierStokesStepper(grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK") + stepper = IncompressibleNavierStokesStepper( + grid=grid, + boundary_conditions=boundary_conditions, + collision_type="BGK", + ) - # Distribute if using JAX backend - if backend == ComputeBackend.JAX: + # Distribute if using JAX + if compute_backend == ComputeBackend.JAX: stepper = distribute( stepper, grid, - xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend), + xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend), ) - # Initialize fields and run simulation + # Initialize fields omega = 1.0 f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields() - start_time = time.time() + start_time = time.time() for i in range(num_steps): f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i) f_0, f_1 = f_1, f_0 wp.synchronize() + elapsed_time = time.time() - start_time - return time.time() - start_time - + return elapsed_time def calculate_mlups(cube_edge, num_steps, elapsed_time): total_lattice_updates = cube_edge**3 * num_steps mlups = (total_lattice_updates / elapsed_time) / 1e6 return mlups +# -------------------------- Simulation Loop -------------------------- -def main(): - args = parse_arguments() - backend, precision_policy = setup_simulation(args) - grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge) - elapsed_time = run(backend, precision_policy, grid_shape, args.num_steps) - mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) +args = parse_arguments() +compute_backend, precision_policy = setup_simulation(args) +grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge) - print(f"Simulation completed in {elapsed_time:.2f} seconds") - print(f"MLUPs: {mlups:.2f}") +elapsed_time = run_simulation( + compute_backend=compute_backend, + precision_policy=precision_policy, + grid_shape=grid_shape, + num_steps=args.num_steps +) +mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) -if __name__ == "__main__": - main() +print(f"Simulation completed in {elapsed_time:.2f} seconds") +print(f"MLUPs: {mlups:.2f}") diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 1025e3bd..94a66e5e 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -9,7 +9,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 6bd9311b..07e68cf4 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index cd189754..b4bc797b 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -10,7 +10,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index 59c6c9d7..3cc15cb3 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -9,7 +9,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index 79d56d89..eb81eda3 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index cebc23f2..4ec0639e 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, diff --git a/tests/grids/test_grid_jax.py b/tests/grids/test_grid_jax.py index dd74da64..7255b7f6 100644 --- a/tests/grids/test_grid_jax.py +++ b/tests/grids/test_grid_jax.py @@ -9,7 +9,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 61b27d4e..2f1ab232 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index 72c2ec99..f3f4308f 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -9,7 +9,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 3c8436c6..aa51ea1d 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -9,7 +9,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index aa4f051a..e0451845 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index fdd796a4..8b7e8f44 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index 2c2ad55e..87e44770 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index 4f33bc23..b4be9a21 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -9,7 +9,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, diff --git a/tests/kernels/stream/test_stream_jax.py b/tests/kernels/stream/test_stream_jax.py index c1cae525..015c21c6 100644 --- a/tests/kernels/stream/test_stream_jax.py +++ b/tests/kernels/stream/test_stream_jax.py @@ -8,7 +8,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py index 95fcc053..e10cc467 100644 --- a/tests/kernels/stream/test_stream_warp.py +++ b/tests/kernels/stream/test_stream_warp.py @@ -10,7 +10,7 @@ def init_xlb_env(velocity_set): - vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, compute_backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.WARP, diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 24eeb031..9002cec3 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -20,7 +20,7 @@ def __init__(self, shape): def _initialize_backend(self): self.nDevices = jax.device_count() - self.backend = jax.default_backend() + self.compute_backend = jax.default_backend() self.device_mesh = ( mesh_utils.create_device_mesh((1, self.nDevices, 1)) if self.dim == 2 else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) ) diff --git a/xlb/helper/check_boundary_overlaps.py b/xlb/helper/check_boundary_overlaps.py index 1adceb2b..831939f0 100644 --- a/xlb/helper/check_boundary_overlaps.py +++ b/xlb/helper/check_boundary_overlaps.py @@ -2,7 +2,7 @@ from xlb.compute_backend import ComputeBackend -def check_bc_overlaps(bclist, dim, backend): +def check_bc_overlaps(bclist, dim, compute_backend): index_list = [[] for _ in range(dim)] for bc in bclist: if bc.indices is None: @@ -10,7 +10,7 @@ def check_bc_overlaps(bclist, dim, backend): # Detect duplicates within bc.indices index_arr = np.unique(bc.indices, axis=-1) if index_arr.shape[-1] != len(bc.indices[0]): - if backend == ComputeBackend.WARP: + if compute_backend == ComputeBackend.WARP: raise ValueError(f"Boundary condition {bc.__class__.__name__} has duplicate indices!") print(f"WARNING: there are duplicate indices in {bc.__class__.__name__} and hence the order in bc list matters!") for d in range(dim): @@ -19,6 +19,6 @@ def check_bc_overlaps(bclist, dim, backend): # Detect duplicates within bclist index_arr = np.unique(index_list, axis=-1) if index_arr.shape[-1] != len(index_list[0]): - if backend == ComputeBackend.WARP: + if compute_backend == ComputeBackend.WARP: raise ValueError("Boundary condition list containes duplicate indices!") print("WARNING: there are duplicate indices in the boundary condition list and hence the order in this list matters!") diff --git a/xlb/helper/initializers.py b/xlb/helper/initializers.py index ccb4a82f..487d2cfa 100644 --- a/xlb/helper/initializers.py +++ b/xlb/helper/initializers.py @@ -2,17 +2,17 @@ from xlb.operator.equilibrium import QuadraticEquilibrium -def initialize_eq(f, grid, velocity_set, precision_policy, backend, rho=None, u=None): +def initialize_eq(f, grid, velocity_set, precision_policy, compute_backend, rho=None, u=None): if rho is None: rho = grid.create_field(cardinality=1, fill_value=1.0, dtype=precision_policy.compute_precision) if u is None: u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0, dtype=precision_policy.compute_precision) equilibrium = QuadraticEquilibrium() - if backend == ComputeBackend.JAX: + if compute_backend == ComputeBackend.JAX: f = equilibrium(rho, u) - elif backend == ComputeBackend.WARP: + elif compute_backend == ComputeBackend.WARP: f = equilibrium(rho, u, f) del rho, u diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 6e8bbbbb..39a0b861 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -22,11 +22,11 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non self.precision_policy = precision_policy or DefaultConfig.default_precision_policy self.compute_backend = compute_backend or DefaultConfig.default_backend - # Check if the compute backend is supported + # Check if the compute compute_backend is supported if self.compute_backend not in ComputeBackend: - raise ValueError(f"Compute backend {compute_backend} is not supported") + raise ValueError(f"Compute_backend {compute_backend} is not supported") - # Construct the kernel based backend functions TODO: Maybe move this to the register or something + # Construct the kernel based compute_backend functions TODO: Maybe move this to the register or something if self.compute_backend == ComputeBackend.WARP: self.warp_functional, self.warp_kernel = self._construct_warp() @@ -39,7 +39,7 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non @classmethod def register_backend(cls, backend_name): """ - Decorator to register a backend for the operator. + Decorator to register a compute_backend for the operator. """ def decorator(func): @@ -58,7 +58,7 @@ def __call__(self, *args, callback=None, **kwargs): bound_arguments = None for key, backend_method in method_candidates: try: - # This attempts to bind the provided args and kwargs to the backend method's signature + # This attempts to bind the provided args and kwargs to the compute_backend method's signature bound_arguments = inspect.signature(backend_method).bind(self, *args, **kwargs) bound_arguments.apply_defaults() # This fills in any default values result = backend_method(self, *args, **kwargs) @@ -99,10 +99,10 @@ def backend(self): This should be used with caution as all backends may not have the same API. """ if self.compute_backend == ComputeBackend.JAX: - import jax.numpy as backend + import jax.numpy as compute_backend elif self.compute_backend == ComputeBackend.WARP: - import warp as backend - return backend + import warp as compute_backend + return compute_backend @property def compute_dtype(self): @@ -128,7 +128,8 @@ def _construct_warp(self): """ Construct the warp functional and kernel of the operator TODO: Maybe a better way to do this? - Maybe add this to the backend decorator? - Leave it for now, as it is not clear how the warp backend will evolve + Maybe add this to the compute backend decorator? + Leave it for now, as it is not clear how the warp compute backend will evolve """ return None, None + diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py index 5324618d..69dad633 100644 --- a/xlb/velocity_set/d2q9.py +++ b/xlb/velocity_set/d2q9.py @@ -13,7 +13,7 @@ class D2Q9(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in two dimensions. """ - def __init__(self, precision_policy, backend): + def __init__(self, precision_policy, compute_backend): # Construct the velocity vectors and weights cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] @@ -21,4 +21,4 @@ def __init__(self, precision_policy, backend): w = np.array([4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36]) # Call the parent constructor - super().__init__(2, 9, c, w, precision_policy=precision_policy, backend=backend) + super().__init__(2, 9, c, w, precision_policy=precision_policy, compute_backend=compute_backend) diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py index 4a48c2f0..ae4ea994 100644 --- a/xlb/velocity_set/d3q19.py +++ b/xlb/velocity_set/d3q19.py @@ -14,7 +14,7 @@ class D3Q19(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ - def __init__(self, precision_policy, backend): + def __init__(self, precision_policy, compute_backend): # Construct the velocity vectors and weights c = np.array([ci for ci in itertools.product([-1, 0, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T w = np.zeros(19) @@ -27,4 +27,4 @@ def __init__(self, precision_policy, backend): w[i] = 1.0 / 36.0 # Initialize the lattice - super().__init__(3, 19, c, w, precision_policy=precision_policy, backend=backend) + super().__init__(3, 19, c, w, precision_policy=precision_policy, compute_backend=compute_backend) diff --git a/xlb/velocity_set/d3q27.py b/xlb/velocity_set/d3q27.py index b056d53e..8110fd6c 100644 --- a/xlb/velocity_set/d3q27.py +++ b/xlb/velocity_set/d3q27.py @@ -14,7 +14,7 @@ class D3Q27(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ - def __init__(self, precision_policy, backend): + def __init__(self, precision_policy, compute_backend): # Construct the velocity vectors and weights c = np.array(list(itertools.product([0, -1, 1], repeat=3))).T w = np.zeros(27) @@ -29,4 +29,4 @@ def __init__(self, precision_policy, backend): w[i] = 1.0 / 216.0 # Initialize the Lattice - super().__init__(3, 27, c, w, precision_policy=precision_policy, backend=backend) + super().__init__(3, 27, c, w, precision_policy=precision_policy, compute_backend=compute_backend) diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 33b2331b..8b8d3213 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -27,27 +27,27 @@ class VelocitySet(object): The weights of the lattice. Shape: (q,) """ - def __init__(self, d, q, c, w, precision_policy, backend): + def __init__(self, d, q, c, w, precision_policy, compute_backend): # Store the dimension and the number of velocities self.d = d self.q = q self.precision_policy = precision_policy - self.backend = backend + self.compute_backend = compute_backend # Updating JAX config in case fp64 is requested - if backend == ComputeBackend.JAX and (precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32): + if compute_backend == ComputeBackend.JAX and (precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32): jax.config.update("jax_enable_x64", True) # Create all properties in NumPy first self._init_numpy_properties(c, w) # Convert properties to backend-specific format - if self.backend == ComputeBackend.WARP: + if self.compute_backend == ComputeBackend.WARP: self._init_warp_properties() - elif self.backend == ComputeBackend.JAX: + elif self.compute_backend == ComputeBackend.JAX: self._init_jax_properties() else: - raise ValueError(f"Unsupported compute backend: {self.backend}") + raise ValueError(f"Unsupported compute backend: {self.compute_backend}") # Set up backend-specific constants self._init_backend_constants() @@ -94,18 +94,19 @@ def _init_jax_properties(self): self.w = jnp.array(self._w, dtype=dtype) self.opp_indices = jnp.array(self._opp_indices, dtype=jnp.int32) self.cc = jnp.array(self._cc, dtype=dtype) + self.c_float = jnp.array(self._c_float, dtype=dtype) self.qi = jnp.array(self._qi, dtype=dtype) def _init_backend_constants(self): """ Initialize the constants for the backend. """ - if self.backend == ComputeBackend.WARP: + if self.compute_backend == ComputeBackend.WARP: dtype = self.precision_policy.compute_precision.wp_dtype self.cs = wp.constant(dtype(self.cs)) self.cs2 = wp.constant(dtype(self.cs2)) self.inv_cs2 = wp.constant(dtype(self.inv_cs2)) - elif self.backend == ComputeBackend.JAX: + elif self.compute_backend == ComputeBackend.JAX: dtype = self.precision_policy.compute_precision.jax_dtype self.cs = jnp.array(self.cs, dtype=dtype) self.cs2 = jnp.array(self.cs2, dtype=dtype)