Skip to content

Commit

Permalink
Fix for Cofunction self-assignment via interpolation (#3939)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: David A. Ham <[email protected]>
  • Loading branch information
jrmaddison and dham authored Feb 4, 2025
1 parent c2500aa commit ecac12b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
7 changes: 6 additions & 1 deletion firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,12 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
V = self.V
result = output or firedrake.Function(V)
with function.dat.vec_ro as x, result.dat.vec_wo as out:
mul(x, out)
if x is not out:
mul(x, out)
else:
out_ = out.duplicate()
mul(x, out_)
out_.copy(result=out)
return result

else:
Expand Down
22 changes: 21 additions & 1 deletion tests/firedrake/regression/test_interp_dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ def f1(mesh, V1):
return Function(V1).interpolate(expr)


def test_interp_self(V1):
a = assemble(conj(TestFunction(V1)) * dx)
b = assemble(conj(TestFunction(V1)) * dx)
a.interpolate(a)
assert np.allclose(a.dat.data_ro, b.dat.data_ro)


def test_assemble_interp_adjoint_tensor(mesh, V1, f1):
a = assemble(conj(TestFunction(V1)) * dx)
# We want tensor to be a dependency of the input expression for this test
assemble(Interpolator(f1 * TestFunction(V1), V1).interpolate(a, adjoint=True),
tensor=a)

x, y = SpatialCoordinate(mesh)
f2 = Function(V1, name="f2").interpolate(
exp(x) * y)

assert np.allclose(assemble(a(f2)), assemble(Function(V1).interpolate(conj(f1 * f2)) * dx))


def test_assemble_interp_operator(V2, f1):
# Check type
If1 = Interpolate(f1, V2)
Expand Down Expand Up @@ -106,7 +126,7 @@ def test_assemble_interp_adjoint_complex(mesh, V1, V2, f1):
f1 = Constant(3 - 5.j) * f1

a = assemble(conj(TestFunction(V1)) * dx)
b = Interpolator(f1 * TestFunction(V2), V1).interpolate(a, adjoint=True)
b = assemble(Interpolator(f1 * TestFunction(V2), V1).interpolate(a, adjoint=True))

x, y = SpatialCoordinate(mesh)
f2 = Function(V2, name="f2").interpolate(
Expand Down

0 comments on commit ecac12b

Please sign in to comment.