Skip to content

Commit

Permalink
solve: stop supporting Vector
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Feb 7, 2025
1 parent 67cebb1 commit 45fffe5
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 49 deletions.
26 changes: 11 additions & 15 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
__all__ = ["LinearSolver"]


class LinearSolver:
class LinearSolver(LinearVariationalSolver):

@PETSc.Log.EventDecorator()
def __init__(self, A, *, P=None, **kwargs):
Expand All @@ -18,7 +18,7 @@ def __init__(self, A, *, P=None, **kwargs):
:arg P: an optional :class:`~.MatrixBase` to construct any
preconditioner from; if none is supplied ``A`` is
used to construct the preconditioner.
:kwarg parameters: (optional) dict of solver parameters.
:kwarg solver_parameters: (optional) dict of solver parameters.
:kwarg nullspace: an optional :class:`~.VectorSpaceBasis` (or
:class:`~.MixedVectorSpaceBasis` spanning the null space
of the operator.
Expand All @@ -31,8 +31,8 @@ def __init__(self, A, *, P=None, **kwargs):
created. Use this option if you want to pass options
to the solver from the command line in addition to
through the ``solver_parameters`` dict.
:kwarg pre_apply_bcs: If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
:kwarg pre_apply_bcs: If `True`, the bcs are applied before the solve.
Otherwise, the bcs are included as part of the linear system.
.. note::
Expand All @@ -44,23 +44,19 @@ def __init__(self, A, *, P=None, **kwargs):
if P is not None and not isinstance(P, MatrixBase):
raise TypeError("Provided preconditioner is a '%s', not a MatrixBase" % type(P).__name__)

test, trial = A.a.arguments()
x = Function(trial.function_space())
b = Cofunction(test.function_space().dual())
test, trial = A.arguments()
self.x = Function(trial.function_space())
self.b = Cofunction(test.function_space().dual())

problem = LinearVariationalProblem(A, b, x, bcs=A.bcs, aP=P)
solver = LinearVariationalSolver(problem, **kwargs)
self.b = b
self.x = x
self.solver = solver
problem = LinearVariationalProblem(A, self.b, self.x, bcs=A.bcs, aP=P)
LinearVariationalSolver.__init__(self, problem, **kwargs)

self.A = A
self.comm = A.comm
self._comm = internal_comm(self.comm, self)
self.P = P if P is not None else A

self.ksp = self.solver.snes.ksp
self.parameters = self.solver.parameters
self.ksp = self.snes.ksp

@PETSc.Log.EventDecorator()
def solve(self, x, b):
Expand All @@ -87,5 +83,5 @@ def solve(self, x, b):

self.x.assign(x)
self.b.assign(b)
self.solver.solve()
super().solve()
x.assign(self.x)
26 changes: 2 additions & 24 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _solve_varproblem(*args, **kwargs):
options_prefix, restrict, pre_apply_bcs = _extract_args(*args, **kwargs)

# Check whether solution is valid
if not isinstance(u, (function.Function, vector.Vector)):
if not isinstance(u, function.Function):
raise TypeError(f"Provided solution is a '{type(u).__name__}', not a Function")

if form_compiler_parameters is None:
Expand Down Expand Up @@ -238,10 +238,6 @@ def _la_solve(A, x, b, **kwargs):
P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace, \
options_prefix, pre_apply_bcs = _extract_linear_solver_args(A, x, b, **kwargs)

# Check whether solution is valid
if not isinstance(x, (function.Function, vector.Vector)):
raise TypeError(f"Provided solution is a '{type(x).__name__}', not a Function")

if bcs is not None:
raise RuntimeError("It is no longer possible to apply or change boundary conditions after assembling the matrix `A`; pass any necessary boundary conditions to `assemble` when assembling `A`.")

Expand All @@ -250,25 +246,7 @@ def _la_solve(A, x, b, **kwargs):
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
options_prefix=options_prefix)
if isinstance(x, firedrake.Vector):
x = x.function
# linear MG doesn't need RHS, supply zero.
L = 0
aP = None if P is None else P.a
lvp = vs.LinearVariationalProblem(A.a, L, x, bcs=A.bcs, aP=aP)
mat_type = A.mat_type
pmat_type = mat_type if P is None else P.mat_type
appctx = solver_parameters.get("appctx", {})
ctx = solving_utils._SNESContext(lvp,
mat_type=mat_type,
pmat_type=pmat_type,
appctx=appctx,
options_prefix=options_prefix,
pre_apply_bcs=pre_apply_bcs)
dm = solver.ksp.dm

with dmhooks.add_hooks(dm, solver, appctx=ctx):
solver.solve(x, b)
solver.solve(x, b)


def _extract_linear_solver_args(*args, **kwargs):
Expand Down
7 changes: 5 additions & 2 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,9 @@ def __init__(self, problem, *, solver_parameters=None,
before residual assembly.
:kwarg post_function_callback: As above, but called immediately
after residual assembly.
:kwarg pre_apply_bcs: If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
:kwarg pre_apply_bcs: If `True`, the bcs are applied before the solve.
Otherwise, the problem is linearised around the initial guess
before imposing bcs, and the bcs are appended to the nonlinear system.
Example usage of the ``solver_parameters`` option: to set the
nonlinear solver type to just use a linear solver, use
Expand Down Expand Up @@ -418,6 +419,8 @@ class LinearVariationalSolver(NonlinearVariationalSolver):
before residual assembly.
:kwarg post_function_callback: As above, but called immediately
after residual assembly.
:kwarg pre_apply_bcs: If `True`, the bcs are applied before the solve.
Otherwise, the bcs are included as part of the linear system.
See also :class:`NonlinearVariationalSolver` for nonlinear problems.
"""
Expand Down
11 changes: 3 additions & 8 deletions tests/firedrake/regression/test_solving_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_linear_solves_equivalent():

f = Function(V)
f.assign(1)
f.vector()[:] = 1.
f.dat.data_wo[:] = 1.
t = TestFunction(V)
q = TrialFunction(V)

Expand All @@ -150,17 +150,12 @@ def test_linear_solves_equivalent():
# And again
sol2 = Function(V)
solve(a == L, sol2)
assert np_norm(sol.vector()[:] - sol2.vector()[:]) == 0
assert np_norm(sol.dat.data_ro - sol2.dat.data_ro) == 0

# Solve the system using preassembled objects
sol3 = Function(V)
solve(assemble(a), sol3, assemble(L))
assert np_norm(sol.vector()[:] - sol3.vector()[:]) < 5e-14

# Same, solving into vector
sol4 = sol3.vector()
solve(assemble(a), sol4, assemble(L))
assert np_norm(sol.vector()[:] - sol4[:]) < 5e-14
assert np_norm(sol.dat.data_ro - sol3.dat.data_ro) < 5e-14


def test_linear_solver_flattens_params(a_L_out):
Expand Down

0 comments on commit 45fffe5

Please sign in to comment.