Skip to content

Commit

Permalink
Separated zero and first kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdiataei committed Sep 19, 2024
1 parent 3ff0814 commit c4c994b
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 2 deletions.
6 changes: 4 additions & 2 deletions xlb/operator/macroscopic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic
from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment
from xlb.operator.macroscopic.macroscopic import Macroscopic
from xlb.operator.macroscopic.second_moment import SecondMoment
from xlb.operator.macroscopic.zero_moment import ZeroMoment
from xlb.operator.macroscopic.first_moment import FirstMoment
83 changes: 83 additions & 0 deletions xlb/operator/macroscopic/first_moment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from functools import partial
import jax.numpy as jnp
from jax import jit
import warp as wp
from typing import Any

from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator

class FirstMoment(Operator):
"""A class to compute the first moment (velocity) of distribution functions."""

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0), inline=True)
def jax_implementation(self, f, rho):
u = jnp.tensordot(self.velocity_set.c, f, axes=(-1, 0)) / rho
return u

def _construct_warp(self):
_c = self.velocity_set.c
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)

@wp.func
def functional(f: _f_vec, rho: float):
u = _u_vec()
for l in range(self.velocity_set.q):
for d in range(self.velocity_set.d):
if _c[d, l] == 1:
u[d] += f[l]
elif _c[d, l] == -1:
u[d] -= f[l]
u /= rho
return u

@wp.kernel
def kernel3d(
f: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
):
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_rho = rho[0, index[0], index[1], index[2]]
_u = functional(_f, _rho)

for d in range(self.velocity_set.d):
u[d, index[0], index[1], index[2]] = _u[d]

@wp.kernel
def kernel2d(
f: wp.array3d(dtype=Any),
rho: wp.array3d(dtype=Any),
u: wp.array3d(dtype=Any),
):
i, j = wp.tid()
index = wp.vec2i(i, j)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
_rho = rho[0, index[0], index[1]]
_u = functional(_f, _rho)

for d in range(self.velocity_set.d):
u[d, index[0], index[1]] = _u[d]

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, rho, u):
wp.launch(
self.warp_kernel,
inputs=[f, rho, u],
dim=u.shape[1:],
)
return u
85 changes: 85 additions & 0 deletions xlb/operator/macroscopic/macroscopic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from functools import partial
import jax.numpy as jnp
from jax import jit
import warp as wp
from typing import Any

from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator
from xlb.operator.macroscopic.zero_moment import ZeroMoment
from xlb.operator.macroscopic.first_moment import FirstMoment

class Macroscopic(Operator):
"""A class to compute both zero and first moments of distribution functions (rho, u)."""

def __init__(self, *args, **kwargs):
self.zero_moment = ZeroMoment(*args, **kwargs)
self.first_moment = FirstMoment(*args, **kwargs)
super().__init__(*args, **kwargs)

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0), inline=True)
def jax_implementation(self, f):
rho = self.zero_moment(f)
u = self.first_moment(f, rho)
return rho, u

def _construct_warp(self):
zero_moment_func = self.zero_moment.warp_functional
first_moment_func = self.first_moment.warp_functional
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)

@wp.func
def functional(f: _f_vec):
rho = zero_moment_func(f)
u = first_moment_func(f, rho)
return rho, u

@wp.kernel
def kernel3d(
f: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
):
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_rho, _u = functional(_f)

rho[0, index[0], index[1], index[2]] = _rho
for d in range(self.velocity_set.d):
u[d, index[0], index[1], index[2]] = _u[d]

@wp.kernel
def kernel2d(
f: wp.array3d(dtype=Any),
rho: wp.array3d(dtype=Any),
u: wp.array3d(dtype=Any),
):
i, j = wp.tid()
index = wp.vec2i(i, j)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
_rho, _u = functional(_f)

rho[0, index[0], index[1]] = _rho
for d in range(self.velocity_set.d):
u[d, index[0], index[1]] = _u[d]

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, rho, u):
wp.launch(
self.warp_kernel,
inputs=[f, rho, u],
dim=rho.shape[1:],
)
return rho, u
69 changes: 69 additions & 0 deletions xlb/operator/macroscopic/zero_moment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from functools import partial
import jax.numpy as jnp
from jax import jit
import warp as wp
from typing import Any

from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator

class ZeroMoment(Operator):
"""A class to compute the zeroth moment (density) of distribution functions."""

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0), inline=True)
def jax_implementation(self, f):
return jnp.sum(f, axis=0, keepdims=True)

def _construct_warp(self):
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)

@wp.func
def functional(f: _f_vec):
rho = self.compute_dtype(0.0)
for l in range(self.velocity_set.q):
rho += f[l]
return rho

@wp.kernel
def kernel3d(
f: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
):
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_rho = functional(_f)

rho[0, index[0], index[1], index[2]] = _rho

@wp.kernel
def kernel2d(
f: wp.array3d(dtype=Any),
rho: wp.array3d(dtype=Any),
):
i, j = wp.tid()
index = wp.vec2i(i, j)

_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
_rho = functional(_f)

rho[0, index[0], index[1]] = _rho

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, rho):
wp.launch(
self.warp_kernel,
inputs=[f, rho],
dim=rho.shape[1:],
)
return rho

0 comments on commit c4c994b

Please sign in to comment.