Skip to content

Commit

Permalink
LinearSolver, support pre_apply_bcs
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Feb 3, 2025
1 parent 20985af commit c86e896
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 38 deletions.
14 changes: 2 additions & 12 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,10 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):

def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
# Homogenize and apply boundary conditions on adj_dFdu.
bcs = self._homogenize_bcs()
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)

for bc in bcs:
bc.zero(dJdu)

adj_sol = firedrake.Function(self.function_space)
firedrake.solve(
dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs
Expand Down Expand Up @@ -526,18 +523,11 @@ def _forward_solve(self, lhs, rhs, func, bcs):
return func

def _assembled_solve(self, lhs, rhs, func, bcs, **kwargs):
rhs_func = rhs.riesz_representation(riesz_map="l2")
for bc in bcs:
bc.apply(rhs_func)
rhs.assign(rhs_func.riesz_representation(riesz_map="l2"))
firedrake.solve(lhs, func, rhs, **kwargs)
return func

def recompute_component(self, inputs, block_variable, idx, prepared):
lhs = prepared[0]
rhs = prepared[1]
func = prepared[2]
bcs = prepared[3]
lhs, rhs, func, bcs = prepared
result = self._forward_solve(lhs, rhs, func, bcs)
if isinstance(block_variable.checkpoint, firedrake.Function):
result = block_variable.checkpoint.assign(result)
Expand Down
27 changes: 10 additions & 17 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class LinearSolver(OptionsManager):
@PETSc.Log.EventDecorator()
def __init__(self, A, *, P=None, solver_parameters=None,
nullspace=None, transpose_nullspace=None,
near_nullspace=None, options_prefix=None):
near_nullspace=None, options_prefix=None,
pre_apply_bcs=True):
"""A linear solver for assembled systems (Ax = b).
:arg A: a :class:`~.MatrixBase` (the operator).
Expand All @@ -40,6 +41,7 @@ def __init__(self, A, *, P=None, solver_parameters=None,
created. Use this option if you want to pass options
to the solver from the command line in addition to
through the ``solver_parameters`` dict.
:kwarg pre_apply_bcs: Ignored by this class.
.. note::
Expand Down Expand Up @@ -116,36 +118,27 @@ def _rhs(self):

u = function.Function(self.trial_space)
b = cofunction.Cofunction(self.test_space.dual())
expr = -action(self.A.a, u)
return u, get_assembler(expr).assemble, b
expr = b - action(self.A.a, u)
return u, get_assembler(expr, bcs=self.A.bcs, zero_bc_nodes=False).assemble, b

def _lifted(self, b):
u, update, blift = self._rhs
u.dat.zero()
for bc in self.A.bcs:
bc.apply(u)
blift.assign(b)
update(tensor=blift)
# blift contains -A u_bc
blift += b
if isinstance(blift, cofunction.Cofunction):
blift_func = blift.riesz_representation(riesz_map="l2")
for bc in self.A.bcs:
bc.apply(blift_func)
blift.assign(blift_func.riesz_representation(riesz_map="l2"))
else:
for bc in self.A.bcs:
bc.apply(blift)
# blift is now b - A u_bc, and satisfies the boundary conditions
return blift

@PETSc.Log.EventDecorator()
def solve(self, x, b):
if not isinstance(x, (function.Function, vector.Vector, cofunction.Cofunction)):
raise TypeError("Provided solution is a '%s', not a Function, Vector or Cofunction" % type(x).__name__)
if not isinstance(x, (function.Function, vector.Vector)):
raise TypeError("Provided solution is a '%s', not a Function or Vector" % type(x).__name__)
if isinstance(b, vector.Vector):
b = b.function
if not isinstance(b, (function.Function, cofunction.Cofunction)):
raise TypeError("Provided RHS is a '%s', not a Function or Cofunction" % type(b).__name__)
if not isinstance(b, (cofunction.Cofunction)):
raise TypeError("Provided RHS is a '%s', not a Cofunction" % type(b).__name__)

# When solving `Ax = b`, with A: V x U -> R, or equivalently A: V -> U*,
# we need to make sure that x and b belong to V and U*, respectively.
Expand Down
9 changes: 4 additions & 5 deletions firedrake/slate/static_condensation/hybridization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools

import numpy as np
import ufl
import finat.ufl

Expand Down Expand Up @@ -37,7 +36,7 @@ def initialize(self, pc):
A KSP is created for the Lagrange multiplier system.
"""
from firedrake import (FunctionSpace, Cofunction, Function, Constant,
from firedrake import (FunctionSpace, Cofunction, Function,
TrialFunction, TrialFunctions, TestFunction,
DirichletBC)
from firedrake.assemble import get_assembler
Expand Down Expand Up @@ -99,7 +98,7 @@ def initialize(self, pc):
self.unbroken_residual = Function(V)

shapes = (V[self.vidx].finat_element.space_dimension(),
np.prod(V[self.vidx].shape))
V[self.vidx].block_size)
domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes
instructions = """
for i, j
Expand Down Expand Up @@ -178,7 +177,7 @@ def initialize(self, pc):
for measure in measures:
Kform += integrand*measure

trace_bcs = [DirichletBC(TraceSpace, Constant(0.0), subdomain) for subdomain in trace_subdomains]
trace_bcs = [DirichletBC(TraceSpace, 0, subdomain) for subdomain in trace_subdomains]

else:
# No bcs were provided, we assume weak Dirichlet conditions.
Expand All @@ -188,7 +187,7 @@ def initialize(self, pc):
trace_subdomains = ["on_boundary"]
if mesh.cell_set._extruded:
trace_subdomains.extend(["bottom", "top"])
trace_bcs = [DirichletBC(TraceSpace, Constant(0.0), subdomain) for subdomain in trace_subdomains]
trace_bcs = [DirichletBC(TraceSpace, 0, subdomain) for subdomain in trace_subdomains]

# Make a SLATE tensor from Kform
K = Tensor(Kform)
Expand Down
10 changes: 6 additions & 4 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _la_solve(A, x, b, **kwargs):
_la_solve(A, x, b, solver_parameters=parameters_dict)."""

P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace, \
options_prefix = _extract_linear_solver_args(A, x, b, **kwargs)
options_prefix, pre_apply_bcs = _extract_linear_solver_args(A, x, b, **kwargs)

# Check whether solution is valid
if not isinstance(x, (function.Function, vector.Vector)):
Expand All @@ -260,7 +260,8 @@ def _la_solve(A, x, b, **kwargs):
mat_type=mat_type,
pmat_type=mat_type,
appctx=appctx,
options_prefix=options_prefix)
options_prefix=options_prefix,
pre_apply_bcs=pre_apply_bcs)
dm = solver.ksp.dm

with dmhooks.add_hooks(dm, solver, appctx=ctx):
Expand All @@ -269,7 +270,7 @@ def _la_solve(A, x, b, **kwargs):

def _extract_linear_solver_args(*args, **kwargs):
valid_kwargs = ["P", "bcs", "solver_parameters", "nullspace",
"transpose_nullspace", "near_nullspace", "options_prefix"]
"transpose_nullspace", "near_nullspace", "options_prefix", "pre_apply_bcs"]
if len(args) != 3:
raise RuntimeError("Missing required arguments, expecting solve(A, x, b, **kwargs)")

Expand All @@ -285,8 +286,9 @@ def _extract_linear_solver_args(*args, **kwargs):
nullspace_T = kwargs.get("transpose_nullspace", None)
near_nullspace = kwargs.get("near_nullspace", None)
options_prefix = kwargs.get("options_prefix", None)
pre_apply_bcs = kwargs.get("pre_apply_bcs", None)

return P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace, options_prefix
return P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace, options_prefix, pre_apply_bcs


def _extract_args(*args, **kwargs):
Expand Down

0 comments on commit c86e896

Please sign in to comment.