Skip to content

Commit

Permalink
MatrixBase: attach fc_params
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Feb 7, 2025
1 parent cfe094b commit ba4a055
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 19 deletions.
3 changes: 2 additions & 1 deletion firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,8 @@ def allocate(self):
self._sub_mat_type,
self._make_maps_and_regions())
return matrix.Matrix(self._form, self._bcs, self._mat_type, sparsity, ScalarType,
options_prefix=self._options_prefix)
options_prefix=self._options_prefix,
fc_params=self._form_compiler_params)

@staticmethod
def _make_sparsity(test, trial, mat_type, sub_mat_type, maps_and_regions):
Expand Down
6 changes: 1 addition & 5 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def __init__(self, A, *, P=None, **kwargs):
through the ``solver_parameters`` dict.
:kwarg pre_apply_bcs: If `True`, the bcs are applied before the solve.
Otherwise, the bcs are included as part of the linear system.
:kwarg form_compiler_parameters: (optional) dict of form compiler
parameters, only used to assemble the lifted residual.
.. note::
Expand All @@ -50,11 +48,9 @@ def __init__(self, A, *, P=None, **kwargs):
self.x = Function(trial.function_space())
self.b = Cofunction(test.function_space().dual())

fc_params = kwargs.pop("form_compiler_parameters", None)
problem = LinearVariationalProblem(A, self.b, self.x, bcs=A.bcs, aP=P,
form_compiler_parameters=fc_params,
form_compiler_parameters=A.fc_params,
constant_jacobian=True)

super().__init__(problem, **kwargs)

self.A = A
Expand Down
21 changes: 14 additions & 7 deletions firedrake/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class MatrixBase(ufl.Matrix):
:class:`MatrixBase`. May be `None` if there are no boundary
conditions to apply.
:arg mat_type: matrix type of assembled matrix, or 'matfree' for matrix-free
:arg fc_params: a dict of form compiler parameters of this matrix
"""
def __init__(self, a, bcs, mat_type):
def __init__(self, a, bcs, mat_type, fc_params):
if isinstance(a, tuple):
self.a = None
test, trial = a
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(self, a, bcs, mat_type):
Matrix type used in the assembly of the PETSc matrix: 'aij', 'baij', 'dense' or 'nest',
or 'matfree' for matrix-free."""
self.fc_params = fc_params

def arguments(self):
if self.a:
Expand Down Expand Up @@ -109,6 +111,8 @@ class Matrix(MatrixBase):
:arg mat_type: matrix type of assembled matrix.
:kwarg fc_params: a dict of form compiler parameters for this matrix.
A ``pyop2.types.mat.Mat`` will be built from the remaining
arguments, for valid values, see ``pyop2.types.mat.Mat`` source code.
Expand All @@ -121,8 +125,9 @@ class Matrix(MatrixBase):
"""

def __init__(self, a, bcs, mat_type, *args, **kwargs):
# sets self._a, self._bcs, and self._mat_type
MatrixBase.__init__(self, a, bcs, mat_type)
# sets self.a, self.bcs, self.mat_type, and self.fc_params
fc_params = kwargs.pop("fc_params", None)
MatrixBase.__init__(self, a, bcs, mat_type, fc_params)
options_prefix = kwargs.pop("options_prefix")
self.M = op2.Mat(*args, mat_type=mat_type, **kwargs)
self.petscmat = self.M.handle
Expand All @@ -146,6 +151,7 @@ class ImplicitMatrix(MatrixBase):
:class:`Matrix`. May be `None` if there are no boundary
conditions to apply.
:kwarg fc_params: a dict of form compiler parameters for this matrix.
.. note::
Expand All @@ -155,8 +161,9 @@ class ImplicitMatrix(MatrixBase):
"""
def __init__(self, a, bcs, *args, **kwargs):
# sets self._a, self._bcs, and self._mat_type
super(ImplicitMatrix, self).__init__(a, bcs, "matfree")
# sets self.a, self.bcs, self.mat_type, and self.fc_params
fc_params = kwargs["fc_params"]
super(ImplicitMatrix, self).__init__(a, bcs, "matfree", fc_params)

options_prefix = kwargs.pop("options_prefix")
appctx = kwargs.get("appctx", {})
Expand All @@ -165,7 +172,7 @@ def __init__(self, a, bcs, *args, **kwargs):
ctx = ImplicitMatrixContext(a,
row_bcs=self.bcs,
col_bcs=self.bcs,
fc_params=kwargs["fc_params"],
fc_params=fc_params,
appctx=appctx)
self.petscmat = PETSc.Mat().create(comm=self._comm)
self.petscmat.setType("python")
Expand Down Expand Up @@ -196,7 +203,7 @@ class AssembledMatrix(MatrixBase):
:arg petscmat: the already constructed petsc matrix this object represents.
"""
def __init__(self, a, bcs, petscmat, *args, **kwargs):
super(AssembledMatrix, self).__init__(a, bcs, "assembled")
super(AssembledMatrix, self).__init__(a, bcs, "assembled", None)

self.petscmat = petscmat

Expand Down
10 changes: 4 additions & 6 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _la_solve(A, x, b, **kwargs):
_la_solve(A, x, b, solver_parameters=parameters_dict)."""

(P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace,
options_prefix, pre_apply_bcs, form_compiler_parameters,
options_prefix, pre_apply_bcs,
) = _extract_linear_solver_args(A, x, b, **kwargs)

if bcs is not None:
Expand All @@ -246,15 +246,14 @@ def _la_solve(A, x, b, **kwargs):
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
options_prefix=options_prefix,
pre_apply_bcs=pre_apply_bcs,
form_compiler_parameters=form_compiler_parameters)
pre_apply_bcs=pre_apply_bcs)
solver.solve(x, b)


def _extract_linear_solver_args(*args, **kwargs):
valid_kwargs = ["P", "bcs", "solver_parameters", "nullspace",
"transpose_nullspace", "near_nullspace", "options_prefix",
"pre_apply_bcs", "form_compiler_parameters"]
"pre_apply_bcs"]
if len(args) != 3:
raise RuntimeError("Missing required arguments, expecting solve(A, x, b, **kwargs)")

Expand All @@ -271,9 +270,8 @@ def _extract_linear_solver_args(*args, **kwargs):
near_nullspace = kwargs.get("near_nullspace", None)
options_prefix = kwargs.get("options_prefix", None)
pre_apply_bcs = kwargs.get("pre_apply_bcs", True)
form_compiler_parameters = kwargs.get("form_compiler_parameters", {})

return P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace, options_prefix, pre_apply_bcs, form_compiler_parameters
return P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace, options_prefix, pre_apply_bcs


def _extract_args(*args, **kwargs):
Expand Down

0 comments on commit ba4a055

Please sign in to comment.