Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fieldsplit: replace empty Forms with ZeroBaseForm #3947

Merged
merged 42 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
a86614a
Restricted Cofunction RHS
pbrubeck Dec 11, 2024
aef7886
Fix BCs on Cofunction
pbrubeck Dec 11, 2024
8e5603e
LinearSolver: check function spaces
pbrubeck Dec 11, 2024
2dd4f76
assemble(form, zero_bc_nodes=True) as default
pbrubeck Dec 12, 2024
3c5e64f
Fix FunctionAssignBlock
pbrubeck Dec 12, 2024
e3449f5
Allow Cofunction.assign take in constants
pbrubeck Dec 12, 2024
b755c81
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 13, 2024
0b0296d
Merge branch 'pbrubeck/fix/restricted-cofunction' of github.com:fired…
pbrubeck Dec 13, 2024
40cf6d7
suggestion from code review
pbrubeck Dec 13, 2024
fe30b48
more suggestions from review
pbrubeck Dec 19, 2024
3d49f31
remove BaseFormAssembler test
pbrubeck Dec 19, 2024
6742374
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 20, 2024
df04f4b
only supply relevant kwargs to OneFormAssembler
pbrubeck Dec 20, 2024
474edb3
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 20, 2024
950e42d
Only interpolate the residual, not every cofunction in the RHS
pbrubeck Dec 21, 2024
a86f3f5
DROP BEFORE MERGE
pbrubeck Dec 21, 2024
337d087
Fix tests
pbrubeck Dec 21, 2024
ed34164
Fix adjoint utils
pbrubeck Dec 22, 2024
027ad37
More robust test for (unrestricted) Cofunction RHS
pbrubeck Dec 22, 2024
885958f
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 23, 2024
2286596
DO NOT MERGE
pbrubeck Dec 30, 2024
bb04bb0
Replace empty Jacobians with ZeroBaseForm
pbrubeck Jan 1, 2025
d82039d
Split Cofunction
pbrubeck Jan 2, 2025
af53302
Do not split off-diagonal blocks if we only want the diagonal
pbrubeck Jan 2, 2025
7f40504
Zero-simplify slate Tensors
pbrubeck Jan 3, 2025
b48c77c
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Jan 3, 2025
d68113f
set bcs directly on diagonal Cofunction
pbrubeck Jan 3, 2025
3d06fc5
ImplicitMatrixContext: handle empty action
pbrubeck Jan 3, 2025
6078f93
Only extract constants referenced in the kernel
pbrubeck Jan 4, 2025
5894b49
Adjoint: only skip expand_derivatives if necessary
pbrubeck Jan 4, 2025
d99ba50
style
pbrubeck Jan 4, 2025
d6bb7dd
EquationBC: do not reconstruct empty Forms
pbrubeck Jan 5, 2025
ed58467
lower degree for EquationBC tests
pbrubeck Jan 6, 2025
2a0c03b
style
pbrubeck Jan 6, 2025
934ff6f
FunctionSpace: multiindex returns subspace
pbrubeck Jan 9, 2025
8688f5e
Revert WithGeometry.__getitem__
pbrubeck Jan 10, 2025
7c2354e
Merge branch 'master' into pbrubeck/simplify-indexed
pbrubeck Jan 15, 2025
e99ce9a
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Jan 15, 2025
70a45fd
Merge branch 'master' into pbrubeck/simplify-indexed
pbrubeck Jan 15, 2025
605e52f
DROP BEFORE MERGE (2)
pbrubeck Jan 15, 2025
f3c4ef6
Do not zero a ZeroBaseForm
pbrubeck Jan 15, 2025
8f7ca9b
Update .github/workflows/build.yml
pbrubeck Jan 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch ufl pbrubeck/simplify-indexed \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
13 changes: 8 additions & 5 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
from firedrake.ufl_expr import derivative, adjoint
from ufl import replace


Expand All @@ -11,7 +12,6 @@ def _ad_annotate_init(init):
@no_annotations
@wraps(init)
def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u_restrict
Expand All @@ -20,10 +20,13 @@ def wrapper(self, *args, **kwargs):
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
dFdu = derivative(self.F, self.u_restrict)
try:
self._ad_adj_F = adjoint(dFdu)
except ValueError:
# Try again without expanding derivatives,
# as dFdu might have been simplied to an empty Form
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
except (TypeError, NotImplementedError):
self._ad_adj_F = None
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
Expand Down
22 changes: 13 additions & 9 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
@staticmethod
def update_tensor(assembled_base_form, tensor):
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
assembled_base_form.dat.copy(tensor.dat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.dat.zero()
else:
assembled_base_form.dat.copy(tensor.dat)
elif isinstance(tensor, matrix.MatrixBase):
# Uses the PETSc copy method.
assembled_base_form.petscmat.copy(tensor.petscmat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.petscmat.zeroEntries()
else:
assembled_base_form.petscmat.copy(tensor.petscmat)
else:
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))

Expand Down Expand Up @@ -1138,7 +1143,7 @@ class OneFormAssembler(ParloopFormAssembler):

Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
1-form.

Notes
Expand Down Expand Up @@ -2127,14 +2132,13 @@ def iter_active_coefficients(form, kinfo):

@staticmethod
def iter_constants(form, kinfo):
"""Yield the form constants"""
"""Yield the form constants referenced in ``kinfo``."""
if isinstance(form, slate.TensorBase):
for const in form.constants():
yield const
all_constants = form.constants()
else:
all_constants = extract_firedrake_constants(form)
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]

@staticmethod
def index_function_spaces(form, indices):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
return
rank = len(self.f.arguments())
splitter = ExtractSubBlock()
if rank == 1:
form = splitter.split(self.f, argument_indices=(row_field, ))
elif rank == 2:
form = splitter.split(self.f, argument_indices=(row_field, col_field))
form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank])
if isinstance(form, ufl.ZeroBaseForm) or form.empty():
# form is empty, do nothing
return
if u is not None:
form = firedrake.replace(form, {self.u: u})
if action_x is not None:
Expand Down
123 changes: 48 additions & 75 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,30 @@
import numpy
import collections

from ufl import as_vector, FormSum, Form, split
from ufl import as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from pyop2 import MixedDat
from pyop2.utils import as_tuple

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.cofunction import Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace


def subspace(V, indices):
if len(indices) == 1:
W = V[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
else:
W = MixedFunctionSpace([V[i] for i in indices])
return W


class ExtractSubBlock(MultiFunction):

"""Extract a sub-block from a form."""
Expand All @@ -30,9 +41,11 @@ def indexed(self, o, child, multiindex):
indices = multiindex.indices()
if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
if len(indices) == 1:
return child.ufl_operands[indices[0]._value]
return child[indices[0]]
elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)):
return child
else:
return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices()))
return ListTensor(*(child[i] for i in indices))
return self.expr(o, child, multiindex)

index_inliner = IndexInliner()
Expand All @@ -52,15 +65,22 @@ def split(self, form, argument_indices):
"""
args = form.arguments()
self._arg_cache = {}
self.blocks = dict(enumerate(argument_indices))
self.blocks = dict(enumerate(map(as_tuple, argument_indices)))
if len(args) == 0:
# Functional can't be split
return form
if all(len(a.function_space()) == 1 for a in args):
assert (len(idx) == 1 for idx in self.blocks.values())
assert (idx[0] == 0 for idx in self.blocks.values())
return form
# TODO find a way to distinguish empty Forms avoiding expand_derivatives
ksagiyam marked this conversation as resolved.
Show resolved Hide resolved
f = map_integrand_dags(self, form)
if expand_derivatives(f).empty():
# Get ZeroBaseForm with the right shape
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(),
self.blocks[arg.number()]),
arg.number(), part=arg.part())
for arg in form.arguments()))
return f

expr = MultiFunction.reuse_if_untouched
Expand Down Expand Up @@ -98,76 +118,42 @@ def argument(self, o):
if o in self._arg_cache:
return self._arg_cache[o]

V_is = V.subfunctions
indices = self.blocks[o.number()]

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )
W = subspace(V, indices)
a = Argument(W, o.number(), part=o.part())
a = (a, ) if len(W) == 1 else split(a)

if len(indices) == 1:
W = V_is[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
a = (Argument(W, o.number(), part=o.part()), )
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
a = split(Argument(W, o.number(), part=o.part()))
args = []
for i in range(len(V_is)):
for i in range(len(V)):
if i in indices:
c = indices.index(i)
a_ = a[c]
if len(a_.ufl_shape) == 0:
args += [a_]
args.append(a_)
else:
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape))
else:
args += [Zero()
for j in numpy.ndindex(V_is[i].value_shape)]
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))

def cofunction(self, o):
V = o.function_space()

# Not on a mixed space, just return ourselves.
if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o

# We only need the test space for Cofunction
# We only need the test space for Cofunction
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
indices = self.blocks[0]
V_is = V.subfunctions

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )

# for two-forms, the cofunction should only
# be returned for the diagonal blocks, so
# if we are asked for an off-diagonal block
# then we return a zero form, analogously to
# the off components of arguments.
if len(self.blocks) == 2:
itest, itrial = self.blocks
on_diag = (itest == itrial)
else:
on_diag = True

# if we are on the diagonal, then return a Cofunction
# in the relevant subspace that points to the data in
# the full space. This means that the right hand side
# of the fieldsplit problem will be correct.
if on_diag:
if len(indices) == 1:
i = indices[0]
W = V_is[i]
W = DualSpace(W.mesh(), W.ufl_element())
c = Cofunction(W, val=o.subfunctions[i].dat)
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
if len(indices) == 1:
i = indices[0]
W = V[i]
W = DualSpace(W.mesh(), W.ufl_element())
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
c = Cofunction(W, val=o.dat[i])
else:
c = ZeroBaseForm(o.arguments())

W = MixedFunctionSpace([V[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
return c


Expand Down Expand Up @@ -207,28 +193,15 @@ def split_form(form, diagonal=False):
args = form.arguments()
shape = tuple(len(a.function_space()) for a in args)
forms = []
rank = len(shape)
if diagonal:
assert len(shape) == 2
assert rank == 2
rank = 1
for idx in numpy.ndindex(shape):
if diagonal:
i, j = idx
if i != j:
continue
f = splitter.split(form, idx)

# does f actually contain anything?
if isinstance(f, Cofunction):
flen = 1
elif isinstance(f, FormSum):
flen = len(f.components())
elif isinstance(f, Form):
flen = len(f.integrals())
else:
raise ValueError(
"ExtractSubBlock.split should have returned an instance of "
"either Form, FormSum, or Cofunction")

if flen > 0:
if diagonal:
i, j = idx
if i != j:
continue
idx = (i, )
forms.append(SplitForm(indices=idx, form=f))
forms.append(SplitForm(indices=idx[:rank], form=f))
return tuple(forms)
28 changes: 17 additions & 11 deletions firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from firedrake.bcs import DirichletBC, EquationBCSplit
from firedrake.petsc import PETSc
from firedrake.utils import cached_property
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from ufl.form import ZeroBaseForm


__all__ = ("ImplicitMatrixContext", )
Expand Down Expand Up @@ -107,23 +110,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

# create functions from test and trial space to help
# with 1-form assembly
test_space, trial_space = [
a.arguments()[i].function_space() for i in (0, 1)
]
from firedrake import function, cofunction
test_space, trial_space = (
arg.function_space() for arg in a.arguments()
)
# Need a cofunction since y receives the assembled result of Ax
self._ystar = cofunction.Cofunction(test_space.dual())
self._y = function.Function(test_space)
self._x = function.Function(trial_space)
self._xstar = cofunction.Cofunction(trial_space.dual())
self._ystar = Cofunction(test_space.dual())
self._y = Function(test_space)
self._x = Function(trial_space)
self._xstar = Cofunction(trial_space.dual())

# These are temporary storage for holding the BC
# values during matvec application. _xbc is for
# the action and ._ybc is for transpose.
if len(self.bcs) > 0:
self._xbc = cofunction.Cofunction(trial_space.dual())
self._xbc = Cofunction(trial_space.dual())
if len(self.col_bcs) > 0:
self._ybc = cofunction.Cofunction(test_space.dual())
self._ybc = Cofunction(test_space.dual())

# Get size information from template vecs on test and trial spaces
trial_vec = trial_space.dof_dset.layout_vec
Expand All @@ -135,6 +137,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

self.action = action(self.a, self._x)
self.actionT = action(self.aT, self._y)
# TODO prevent action from returning empty Forms
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
if self.action.empty():
self.action = ZeroBaseForm(self.a.arguments()[:-1])
if self.actionT.empty():
self.actionT = ZeroBaseForm(self.aT.arguments()[:-1])

# For assembling action(f, self._x)
self.bcs_action = []
Expand Down Expand Up @@ -170,7 +177,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

@cached_property
def _diagonal(self):
from firedrake import Cofunction
assert self.on_diag
return Cofunction(self._x.function_space().dual())

Expand Down
2 changes: 1 addition & 1 deletion firedrake/preconditioners/massinv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC):
context, keyed on ``"mu"``.
"""
def form(self, pc, test, trial):
_, bcs = super(MassInvPC, self).form(pc, test, trial)
_, bcs = super(MassInvPC, self).form(pc)

appctx = self.get_appctx(pc)
mu = appctx.get("mu", 1.0)
Expand Down
Loading
Loading