Skip to content

Commit

Permalink
compiler: make SafeInv dtype handled (eps, plain div)
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 20, 2025
1 parent 9958b9b commit 8c11f8d
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 7 deletions.
9 changes: 5 additions & 4 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,15 +626,16 @@ class Mod(DifferentiableOp, sympy.Mod):
class SafeInv(Differentiable, sympy.core.function.Application):
_fd_priority = 0

def __new__(cls, val, base, **kwargs):
return super().__new__(cls, val, base)

@property
def base(self):
return self.args[1]

@property
def val(self):
return self.args[0]

def __str__(self):
return f'1/({self.args[0]})'
return Pow(self.args[0], -1).__str__()

__repr__ = __str__

Expand Down
7 changes: 6 additions & 1 deletion devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ def _(expr):

@_lower_macro_math.register(SafeInv)
def _(expr):
return (('SAFEINV(a, b)', '(((a) < 1e-12 || (b) < 1e-12) ? (0.0F) : (1.0F / (a)))'),)
if expr.dtype(0).itemsize <= 4:
eps = np.finfo(expr.dtype).resolution**2
return (('SAFEINV(a, b)',
f'(((a) < {eps} || (b) < {eps}) ? (0.0F) : (1.0F / (a)))'),)
else:
return ()


@iet_pass
Expand Down
5 changes: 4 additions & 1 deletion devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,11 @@ def _print_Pow(self, expr):

def _print_SafeInv(self, expr):
"""Print a SafeInv as a C-like division with a check for zero."""
if not self.single_prec(expr):
return self._print(1/expr.val)

base = self._print(expr.base)
val = self._print(expr.args[0])
val = self._print(expr.val)
return f'SAFEINV({val}, {base})'

def _print_Mod(self, expr):
Expand Down
2 changes: 1 addition & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,7 @@ def _evaluate(self, **kwargs):
# is called again within FD
if self._avg_mode == 'harmonic':
from devito.finite_differences.differentiable import SafeInv
retval = SafeInv(retval.evaluate, base=self.function)
retval = SafeInv(retval.evaluate, self.function)
else:
retval = retval.evaluate

Expand Down
15 changes: 15 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
Min, Max)
from devito.finite_differences.differentiable import SafeInv
from devito.ir import Expression, FindNodes
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
CallFromPointer, Cast, DefFunction, FieldFromPointer,
Expand Down Expand Up @@ -345,6 +346,20 @@ def test_intdiv():
assert ccode(v) == 'b*((a + b) / 2) + 3'


def test_safeinv():
grid = Grid(shape=(11, 11))
x, y = grid.dimensions

u1 = Function(name='u', grid=grid)
u2 = Function(name='u', grid=grid, dtype=np.float64)

op1 = Operator(Eq(u1, SafeInv(u1, u1)))
op2 = Operator(Eq(u2, SafeInv(u2, u2)))

assert 'SAFEINV' in str(op1)
assert 'SAFEINV' not in str(op2)


def test_def_function():
foo0 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
foo1 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
Expand Down

0 comments on commit 8c11f8d

Please sign in to comment.