Skip to content

Commit

Permalink
compiler:m split compilation and argument processing
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 20, 2025
1 parent 9b8a144 commit 9e8f4df
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 170 deletions.
16 changes: 8 additions & 8 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,19 +627,17 @@ class SafeInv(Differentiable, sympy.core.function.Application):
_fd_priority = 0

def __new__(cls, val, base, **kwargs):
newobj = super().__new__(cls, val, base)
return newobj

def _evaluate(self, **kwargs):
return self.func(*[i._evaluated(**kwargs) for i in self.args])

def _subs(self, old, new, **hints):
return self.func(*[i._subs(old, new, **hints) for i in self.args])
return super().__new__(cls, val, base)

@property
def base(self):
return self.args[1]

def __str__(self):
return f'1/({self.args[0]})'

__repr__ = __str__


class IndexSum(sympy.Expr, Evaluable):

Expand Down Expand Up @@ -693,6 +691,8 @@ def __repr__(self):
def _sympystr(self, printer):
return str(self)

_latex = _sympystr

def _hashable_content(self):
return super()._hashable_content() + (self.dimensions,)

Expand Down
8 changes: 7 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,14 +894,20 @@ def apply(self, **kwargs):
>>> op = Operator(Eq(u3.forward, u3 + 1))
>>> summary = op.apply(time_M=10)
"""
# Compile the operator before building the arguments list
# to avoid out of memory with greedy compilers
try:
cfunction = self.cfunction
except Exception as e:
raise CompilationError(f"Failed to compile the Operator {self.name}") from e

# Build the arguments list to invoke the kernel function
with self._profiler.timer_on('arguments'):
args = self.arguments(**kwargs)

# Invoke kernel function with args
arg_values = [args[p.name] for p in self.parameters]
try:
cfunction = self.cfunction
with self._profiler.timer_on('apply', comm=args.comm):
retval = cfunction(*arg_values)
except ctypes.ArgumentError as e:
Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ def estimate_cost(exprs, estimate=False):
estimate_values = {
'elementary': 100,
'pow': 50,
'SafeInv': 10,
'div': 5,
'Abs': 5,
'floor': 1,
'ceil': 1,
'SafeDiv': 10,
}


Expand Down
280 changes: 122 additions & 158 deletions examples/seismic/tutorials/06_elastic_varying_parameters.ipynb

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pytest

from devito import Function, Grid, Differentiable, NODE
from devito.finite_differences.differentiable import Add, Mul, Pow, diffify, interp_for_fd
from devito.finite_differences.differentiable import (Add, Mul, Pow, diffify,
interp_for_fd, SafeInv)


def test_differentiable():
Expand Down Expand Up @@ -113,4 +114,7 @@ def test_avg_mode(ndim):
assert sympy.simplify(a_avg - 0.5**ndim * sum(a.subs(arg) for arg in args)) == 0

# Harmonic average, h(a[.5]) = 1/(.5/a[0] + .5/a[1])
assert sympy.simplify(b_avg - 1/(0.5**ndim * sum(1/b.subs(arg) for arg in args))) == 0
expected = 1/(0.5**ndim * sum(1/b.subs(arg) for arg in args))
assert sympy.simplify(1/b_avg.args[0] - expected) == 0
assert isinstance(b_avg, SafeInv)
assert b_avg.base == b

0 comments on commit 9e8f4df

Please sign in to comment.