Skip to content

Commit

Permalink
compiler: Add condition not to increment during injection, disregardi…
Browse files Browse the repository at this point in the history
…ng interpolation weights.
  • Loading branch information
fffarias committed May 24, 2022
1 parent 77aec6c commit 4feee17
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,14 @@ class Injection(UnevaluatedSparseOperation):
Evaluates to a list of Eq objects.
"""

def __new__(cls, field, expr, offset, interpolator, callback):
def __new__(cls, field, expr, offset, increment, interpolator, callback):
obj = super().__new__(cls, interpolator, callback)

# TODO: unused now, but will be necessary to compute the adjoint
obj.field = field
obj.expr = expr
obj.offset = offset
obj.increment = increment

return obj

Expand Down Expand Up @@ -255,7 +256,7 @@ def callback():

return Interpolation(expr, offset, increment, self_subs, self, callback)

def inject(self, field, expr, offset=0):
def inject(self, field, expr, offset=0, increment=True):
"""
Generate equations injecting an arbitrary expression into a field.
Expand All @@ -267,6 +268,8 @@ def inject(self, field, expr, offset=0):
Injected expression.
offset : int, optional
Additional offset from the boundary.
increment: bool, optional
If True, generate increments (Inc) rather than assignments (Eq).
"""
def callback():
# Derivatives must be evaluated before the introduction of indirect accesses
Expand All @@ -285,13 +288,18 @@ def callback():
field_offset=field_offset)

# Substitute coordinate base symbols into the interpolation coefficients
eqns = [Inc(field.xreplace(vsub), _expr.xreplace(vsub) * b,
implicit_dims=self.sfunction.dimensions)
for b, vsub in zip(self._interpolation_coeffs, idx_subs)]
if increment:
eqns = [Inc(field.xreplace(vsub), _expr.xreplace(vsub) * b,
implicit_dims=self.sfunction.dimensions)
for b, vsub in zip(self._interpolation_coeffs, idx_subs)]
else:
eqns = [Eq(field.xreplace(idx_subs[0]), _expr.xreplace(idx_subs[0]),
implicit_dims=self.sfunction.dimensions)]


return temps + eqns

return Injection(field, expr, offset, self, callback)
return Injection(field, expr, offset, increment, self, callback)


class PrecomputedInterpolator(GenericInterpolator):
Expand Down Expand Up @@ -388,3 +396,4 @@ def callback():
return [Eq(_field, _field + rhs.subs(dim_subs))]

return Injection(field, expr, offset, self, callback)

0 comments on commit 4feee17

Please sign in to comment.