Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make LinearSolver a subclass of LinearVariationalSolver #4012

Open
wants to merge 34 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
956bb72
FEniCS-style bcs
pbrubeck Jan 27, 2025
20985af
Delay Dirichlet Lifting for LVP, MG, and Fieldsplit
pbrubeck Jan 29, 2025
89fdfef
LinearSolver: support pre_apply_bcs
pbrubeck Feb 3, 2025
8f30e7e
Merge branch 'master' into pbrubeck/feature/fenics-bcs
pbrubeck Feb 4, 2025
6ab5071
Update firedrake/assemble.py
pbrubeck Feb 5, 2025
e4c0249
fixup
pbrubeck Feb 5, 2025
29fa679
Merge branch 'pbrubeck/feature/fenics-bcs' of github.com:firedrakepro…
pbrubeck Feb 5, 2025
f4514a3
address review comments
pbrubeck Feb 5, 2025
59fc6ed
Make LinearSolver a wrapper of LinearVariationalSolver
pbrubeck Feb 5, 2025
9d34989
Fix tests
pbrubeck Feb 5, 2025
400ae30
Allow combinations of slate.Tensors and ufl.BaseForm
pbrubeck Feb 5, 2025
7558e79
Only extract form if the assembled matrix has bcs
pbrubeck Feb 6, 2025
e761c6f
Fieldsplit: handle MatrixBase
pbrubeck Feb 6, 2025
eb796b6
Fieldsplit: handle bc lifting for linear problem
pbrubeck Feb 6, 2025
e695cd2
Fix tests
pbrubeck Feb 6, 2025
254c83d
Apply suggestions from code review
pbrubeck Feb 6, 2025
8cc1d5f
Update firedrake/linear_solver.py
pbrubeck Feb 7, 2025
380def5
Merge branch 'master' into pbrubeck/linear-solver
pbrubeck Feb 7, 2025
fd4e537
Slate: syntax sugar
pbrubeck Feb 7, 2025
4abd841
Merge branch 'master' into pbrubeck/feature/fenics-bcs
pbrubeck Feb 7, 2025
9a022bc
Merge branch 'pbrubeck/feature/fenics-bcs' into pbrubeck/linear-solver
pbrubeck Feb 7, 2025
67cebb1
Projector: deprecate quadrature_degree kwarg
pbrubeck Feb 7, 2025
45fffe5
solve: stop supporting Vector
pbrubeck Feb 7, 2025
cfe094b
LinearSolver: form_compiler_parameters
pbrubeck Feb 7, 2025
ba4a055
MatrixBase: attach fc_params
pbrubeck Feb 7, 2025
da432b1
fix up
pbrubeck Feb 7, 2025
21db7ff
Update firedrake/preconditioners/base.py
pbrubeck Feb 7, 2025
dff1d67
whitespace
pbrubeck Feb 7, 2025
7958933
Fix broken stokes test
pbrubeck Feb 7, 2025
53ebdf1
Merge branch 'master' into pbrubeck/feature/fenics-bcs
pbrubeck Feb 12, 2025
be79f04
merge conflict
pbrubeck Feb 12, 2025
a6afb91
Test solve(assembled(a, bcs=bcs) == L, u)
pbrubeck Feb 12, 2025
85806a2
Test slate casting
pbrubeck Feb 12, 2025
d15720d
merge conflict
pbrubeck Feb 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,10 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):

def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
# Homogenize and apply boundary conditions on adj_dFdu.
bcs = self._homogenize_bcs()
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)

for bc in bcs:
bc.zero(dJdu)

adj_sol = firedrake.Function(self.function_space)
firedrake.solve(
dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs
Expand Down Expand Up @@ -526,18 +523,11 @@ def _forward_solve(self, lhs, rhs, func, bcs):
return func

def _assembled_solve(self, lhs, rhs, func, bcs, **kwargs):
rhs_func = rhs.riesz_representation(riesz_map="l2")
for bc in bcs:
bc.apply(rhs_func)
rhs.assign(rhs_func.riesz_representation(riesz_map="l2"))
firedrake.solve(lhs, func, rhs, **kwargs)
return func

def recompute_component(self, inputs, block_variable, idx, prepared):
lhs = prepared[0]
rhs = prepared[1]
func = prepared[2]
bcs = prepared[3]
lhs, rhs, func, bcs = prepared
result = self._forward_solve(lhs, rhs, func, bcs)
if isinstance(block_variable.checkpoint, firedrake.Function):
result = block_variable.checkpoint.assign(result)
Expand Down
81 changes: 54 additions & 27 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def assemble(expr, *args, **kwargs):
`matrix.Matrix`.
is_base_form_preprocessed : bool
If `True`, skip preprocessing of the form.
current_state : firedrake.function.Function or None
If provided and ``zero_bc_nodes == False``, the boundary condition
nodes of the output are set to the residual of the boundary conditions
computed as ``current_state`` minus the boundary condition value.

Returns
-------
Expand Down Expand Up @@ -130,16 +134,21 @@ def assemble(expr, *args, **kwargs):
"""
if args:
raise RuntimeError(f"Got unexpected args: {args}")
tensor = kwargs.pop("tensor", None)
return get_assembler(expr, *args, **kwargs).assemble(tensor=tensor)

assemble_kwargs = {}
for key in ("tensor", "current_state"):
if key in kwargs:
assemble_kwargs[key] = kwargs.pop(key, None)
return get_assembler(expr, *args, **kwargs).assemble(**assemble_kwargs)


def get_assembler(form, *args, **kwargs):
"""Create an assembler.

Notes
-----
See `assemble` for descriptions of the parameters. ``tensor`` should not be passed to this function.
See `assemble` for descriptions of the parameters. ``tensor`` and
``current_state`` should not be passed to this function.

"""
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
Expand Down Expand Up @@ -187,13 +196,15 @@ class ExprAssembler(object):
def __init__(self, expr):
self._expr = expr

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the pointwise expression.

Parameters
----------
tensor : firedrake.function.Function or firedrake.cofunction.Cofunction or matrix.MatrixBase
Output tensor.
current_state : None
Ignored by this class.

Returns
-------
Expand All @@ -205,6 +216,7 @@ def assemble(self, tensor=None):
from ufl.checks import is_scalar_constant_expression

assert tensor is None
assert current_state is None
expr = self._expr
# Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`)
base_form_operators = extract_base_form_operators(expr)
Expand Down Expand Up @@ -274,13 +286,16 @@ def allocate(self):
"""Allocate memory for the output tensor."""

@abc.abstractmethod
def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.

Parameters
----------
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual
computed as ``current_state`` minus the boundary condition value.

Returns
-------
Expand Down Expand Up @@ -329,7 +344,9 @@ def __init__(self,
def allocate(self):
rank = len(self._form.arguments())
if rank == 2 and not self._diagonal:
if self._mat_type == "matfree":
if isinstance(self._form, matrix.MatrixBase):
return self._form
elif self._mat_type == "matfree":
return MatrixFreeAssembler(self._form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
options_prefix=self._options_prefix,
appctx=self._appctx).allocate()
Expand Down Expand Up @@ -358,13 +375,16 @@ def allocation_integral_types(self):
else:
return self._allocation_integral_types

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.

Parameters
----------
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
Output tensor to contain the result of assembly.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual
computed as ``current_state`` minus the boundary condition value.

Returns
-------
Expand All @@ -389,7 +409,7 @@ def visitor(e, *operands):
rank = len(self._form.arguments())
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
for bc in self._bcs:
bc.zero(result)
OneFormAssembler._apply_bc(self, result, bc, u=current_state)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
Expand Down Expand Up @@ -968,13 +988,16 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
self._needs_zeroing = needs_zeroing

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.

Parameters
----------
tensor : firedrake.cofunction.Cofunction or matrix.MatrixBase
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual
computed as ``current_state`` minus the boundary condition value.

Returns
-------
Expand All @@ -998,12 +1021,12 @@ def assemble(self, tensor=None):
self.execute_parloops(tensor)

for bc in self._bcs:
self._apply_bc(tensor, bc)
self._apply_bc(tensor, bc, u=current_state)

return self.result(tensor)

@abc.abstractmethod
def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
"""Apply boundary condition."""

@abc.abstractmethod
Expand Down Expand Up @@ -1138,7 +1161,7 @@ def allocate(self):
comm=self._form.ufl_domains()[0]._comm
)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
pass

def _check_tensor(self, tensor):
Expand Down Expand Up @@ -1199,26 +1222,29 @@ def allocate(self):
else:
raise RuntimeError(f"Not expected: found rank = {rank} and diagonal = {self._diagonal}")

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
# TODO Maybe this could be a singledispatchmethod?
if isinstance(bc, DirichletBC):
self._apply_dirichlet_bc(tensor, bc)
if self._diagonal:
bc.set(tensor, self._weight)
elif self._zero_bc_nodes:
bc.zero(tensor)
else:
# The residual belongs to a mixed space that is dual on the boundary nodes
# and primal on the interior nodes. Therefore, this is a type-safe operation.
r = tensor.riesz_representation("l2")
bc.apply(r, u=u)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
OneFormAssembler(bc.f, bcs=bc.bcs,
form_compiler_parameters=self._form_compiler_params,
needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes,
diagonal=self._diagonal,
weight=self._weight).assemble(tensor=tensor, current_state=u)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")
Expand Down Expand Up @@ -1430,7 +1456,8 @@ def _all_assemblers(self):
all_assemblers.extend(_assembler._all_assemblers)
return tuple(all_assemblers)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
assert u is None
op2tensor = tensor.M
spaces = tuple(a.function_space() for a in tensor.a.arguments())
V = bc.function_space()
Expand Down Expand Up @@ -1534,7 +1561,7 @@ def allocate(self):
options_prefix=self._options_prefix,
appctx=self._appctx or {})

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
if tensor is None:
tensor = self.allocate()
else:
Expand Down
28 changes: 14 additions & 14 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def extract_form(self, form_type):
# DirichletBC is directly used in assembly.
return self

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self


Expand Down Expand Up @@ -501,15 +501,16 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp
# linear
if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form):
J = eq.lhs
L = eq.rhs
Jp = Jp or J
if eq.rhs == 0:
if L == 0 or L.empty():
F = ufl_expr.action(J, u)
else:
if not isinstance(eq.rhs, (ufl.Form, slate.slate.TensorBase)):
raise TypeError("Provided BC RHS is a '%s', not a Form or Slate Tensor" % type(eq.rhs).__name__)
if len(eq.rhs.arguments()) != 1:
if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided BC RHS is a '%s', not a BaseForm or Slate Tensor" % type(L).__name__)
if len(L.arguments()) != 1:
raise ValueError("Provided BC RHS is not a linear form")
F = ufl_expr.action(J, u) - eq.rhs
F = ufl_expr.action(J, u) - L
self.is_linear = True
# nonlinear
else:
Expand All @@ -531,9 +532,7 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp
# reconstruction for splitting `solving_utils.split`
self.Jp_eq_J = Jp_eq_J
self.is_linear = is_linear
self._F = args[0]
self._J = args[1]
self._Jp = args[2]
self._F, self._J, self._Jp = args[:3]
else:
raise TypeError("Wrong EquationBC arguments")

Expand Down Expand Up @@ -562,7 +561,7 @@ def reconstruct(self, V, subu, u, field, is_linear):
if all([_F is not None, _J is not None, _Jp is not None]):
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear)

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self


Expand Down Expand Up @@ -654,19 +653,20 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
ebc.add(bc_temp)
return ebc

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
# NonlinearVariationalProblem expects EquationBC, not EquationBCSplit.
# -- This method is required when NonlinearVariationalProblem is constructed inside PC.
if len(self.f.arguments()) != 2:
raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)")
J = self.f
Vcol = J.arguments()[-1].function_space()
u = firedrake.Function(Vcol)
F = ufl_expr.action(J, u)
Vrow = self._function_space
sub_domain = self.sub_domain
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs)
return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow)
bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=is_linear) for bc in self.bcs)
lhs = J if is_linear else ufl_expr.action(J, u)
rhs = ufl.Form([]) if is_linear else 0
return EquationBC(lhs == rhs, u, sub_domain, bcs=bcs, J=J, V=Vrow)


@PETSc.Log.EventDecorator()
Expand Down
24 changes: 24 additions & 0 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from firedrake.petsc import PETSc
from firedrake.functionspace import MixedFunctionSpace
from firedrake.cofunction import Cofunction
from firedrake.matrix import AssembledMatrix


def subspace(V, indices):
Expand Down Expand Up @@ -147,6 +148,29 @@ def cofunction(self, o):
else:
return Cofunction(W, val=MixedDat(o.dat[i] for i in indices))

def matrix(self, o):
ises = []
args = []
for a in o.arguments():
V = a.function_space()
iset = PETSc.IS()
if a.number() in self.blocks:
asplit = self._subspace_argument(a)
for f in self.blocks[a.number()]:
fset = V.dof_dset.field_ises[f]
iset = iset.expand(fset)
else:
asplit = a
for fset in V.dof_dset.field_ises:
iset = iset.expand(fset)

ises.append(iset)
args.append(asplit)

submat = o.petscmat.createSubMatrix(*ises)
bcs = ()
return AssembledMatrix(tuple(args), bcs, submat)
JHopeCollins marked this conversation as resolved.
Show resolved Hide resolved


SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])

Expand Down
Loading
Loading