Skip to content

Commit

Permalink
Fix fieldsplit with Cofunction right hand side. (#3932)
Browse files Browse the repository at this point in the history
* Add cofunction handler to form splitter to enable fieldsplit with cofunctions.

* Apply suggestions from code review

Co-authored-by: Pablo Brubeck <[email protected]>

* raise error if ExtractSubBlock.split doesn't give back a form-like type

* higher level imports

* use dual space for splitting a cofunction

* test matrix-free fieldsplit with cofunction rhs

* review updates

* create sub dual space directly for split cofunction

* split logic via isinstance not TypeError

* test splitting a 2-form with off-diagonal blocks and a cofunction rhs

---------

Co-authored-by: Pablo Brubeck <[email protected]>
  • Loading branch information
JHopeCollins and pbrubeck authored Jan 2, 2025
1 parent 7ed6ff0 commit bfb7a19
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 15 deletions.
80 changes: 68 additions & 12 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
import numpy
import collections

from ufl import as_vector
from ufl.classes import Zero, FixedIndex, ListTensor
from ufl import as_vector, FormSum, Form, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from pyop2 import MixedDat

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.cofunction import Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace


class ExtractSubBlock(MultiFunction):
Expand Down Expand Up @@ -85,9 +89,8 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds):

@PETSc.Log.EventDecorator()
def argument(self, o):
from ufl import split
from firedrake import MixedFunctionSpace, FunctionSpace
V = o.function_space()

if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o
Expand All @@ -98,15 +101,11 @@ def argument(self, o):
V_is = V.subfunctions
indices = self.blocks[o.number()]

try:
indices = tuple(indices)
nidx = len(indices)
except TypeError:
# Only one index provided.
# Only one index provided.
if isinstance(indices, int):
indices = (indices, )
nidx = 1

if nidx == 1:
if len(indices) == 1:
W = V_is[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
a = (Argument(W, o.number(), part=o.part()), )
Expand All @@ -127,6 +126,50 @@ def argument(self, o):
for j in numpy.ndindex(V_is[i].value_shape)]
return self._arg_cache.setdefault(o, as_vector(args))

def cofunction(self, o):
V = o.function_space()

# Not on a mixed space, just return ourselves.
if len(V) == 1:
return o

# We only need the test space for Cofunction
indices = self.blocks[0]
V_is = V.subfunctions

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )

# for two-forms, the cofunction should only
# be returned for the diagonal blocks, so
# if we are asked for an off-diagonal block
# then we return a zero form, analogously to
# the off components of arguments.
if len(self.blocks) == 2:
itest, itrial = self.blocks
on_diag = (itest == itrial)
else:
on_diag = True

# if we are on the diagonal, then return a Cofunction
# in the relevant subspace that points to the data in
# the full space. This means that the right hand side
# of the fieldsplit problem will be correct.
if on_diag:
if len(indices) == 1:
i = indices[0]
W = V_is[i]
W = DualSpace(W.mesh(), W.ufl_element())
c = Cofunction(W, val=o.subfunctions[i].dat)
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
else:
c = ZeroBaseForm(o.arguments())

return c


SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])

Expand Down Expand Up @@ -168,7 +211,20 @@ def split_form(form, diagonal=False):
assert len(shape) == 2
for idx in numpy.ndindex(shape):
f = splitter.split(form, idx)
if len(f.integrals()) > 0:

# does f actually contain anything?
if isinstance(f, Cofunction):
flen = 1
elif isinstance(f, FormSum):
flen = len(f.components())
elif isinstance(f, Form):
flen = len(f.integrals())
else:
raise ValueError(
"ExtractSubBlock.split should have returned an instance of "
"either Form, FormSum, or Cofunction")

if flen > 0:
if diagonal:
i, j = idx
if i != j:
Expand Down
7 changes: 6 additions & 1 deletion tests/firedrake/regression/test_linesmoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def backend(request):
return request.param


def test_linesmoother(mesh, S1family, expected, backend):
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
def test_linesmoother(mesh, S1family, expected, backend, rhs):
base_cell = mesh._base_mesh.ufl_cell()
S2family = "DG" if base_cell.is_simplex() else "DQ"
DGfamily = "DG" if mesh.ufl_cell().is_simplex() else "DQ"
Expand Down Expand Up @@ -86,6 +87,10 @@ def test_linesmoother(mesh, S1family, expected, backend):
f = exp(-rsq)

L = inner(f, q)*dx(degree=2*(degree+1))
if rhs == 'cofunc_rhs':
L = assemble(L)
elif rhs != 'form_rhs':
raise ValueError("Unknown right hand side type")

w0 = Function(W)
problem = LinearVariationalProblem(a, L, w0, bcs=bcs, aP=aP, form_compiler_parameters={"mode": "vanilla"})
Expand Down
7 changes: 6 additions & 1 deletion tests/firedrake/regression/test_matrix_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_matrixfree_action(a, V, bcs):

@pytest.mark.parametrize("preassembled", [False, True],
ids=["variational", "preassembled"])
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
@pytest.mark.parametrize("parameters",
[{"ksp_type": "preonly",
"pc_type": "python",
Expand Down Expand Up @@ -168,7 +169,7 @@ def test_matrixfree_action(a, V, bcs):
"fieldsplit_1_fieldsplit_1_pc_type": "python",
"fieldsplit_1_fieldsplit_1_pc_python_type": "firedrake.AssembledPC",
"fieldsplit_1_fieldsplit_1_assembled_pc_type": "lu"}])
def test_fieldsplitting(mesh, preassembled, parameters):
def test_fieldsplitting(mesh, preassembled, parameters, rhs):
V = FunctionSpace(mesh, "CG", 1)
P = FunctionSpace(mesh, "DG", 0)
Q = VectorFunctionSpace(mesh, "DG", 1)
Expand All @@ -185,6 +186,10 @@ def test_fieldsplitting(mesh, preassembled, parameters):
a = inner(u, v)*dx

L = inner(expect, v)*dx
if rhs == 'cofunc_rhs':
L = assemble(L)
elif rhs != 'form_rhs':
raise ValueError("Unknown right hand side type")

f = Function(W)

Expand Down
7 changes: 6 additions & 1 deletion tests/firedrake/regression/test_nullspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ def test_nullspace_mixed_multiple_components():

@pytest.mark.parallel(nprocs=2)
@pytest.mark.parametrize("aux_pc", [False, True], ids=["PC(mu)", "PC(DG0-mu)"])
def test_near_nullspace_mixed(aux_pc):
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
def test_near_nullspace_mixed(aux_pc, rhs):
# test nullspace and nearnullspace for a mixed Stokes system
# this is tested on the SINKER case of May and Moresi https://doi.org/10.1016/j.pepi.2008.07.036
# fails in parallel if nullspace is copied to fieldsplit_1_Mp_ksp solve (see PR #3488)
Expand Down Expand Up @@ -323,6 +324,10 @@ def test_near_nullspace_mixed(aux_pc):

f = as_vector((0, -9.8*conditional(inside_box, 2, 1)))
L = inner(f, v)*dx
if rhs == 'cofunc_rhs':
L = assemble(L)
elif rhs != 'form_rhs':
raise ValueError("Unknown right hand side type")

bcs = [DirichletBC(W[0].sub(0), 0, (1, 2)), DirichletBC(W[0].sub(1), 0, (3, 4))]

Expand Down

0 comments on commit bfb7a19

Please sign in to comment.