From a86614a83da9f7353b175b7b93b1efe74b4660eb Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 11 Dec 2024 18:22:10 +0000 Subject: [PATCH 01/31] Restricted Cofunction RHS --- firedrake/assemble.py | 2 +- firedrake/functionspaceimpl.py | 4 ++++ firedrake/variational_solver.py | 8 ++++++-- .../regression/test_restricted_function_space.py | 9 ++++++--- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f451b3f596..6b1bb7d3e4 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1201,7 +1201,7 @@ def _apply_dirichlet_bc(self, tensor, bc): bc.zero(tensor) def _check_tensor(self, tensor): - if tensor.function_space() != self._form.arguments()[0].function_space(): + if tensor.function_space() != self._form.arguments()[0].function_space().dual(): raise ValueError("Form's argument does not match provided result tensor") @staticmethod diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 8fc81244f7..7e7ecfbcbf 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -13,6 +13,7 @@ import ufl import finat.ufl +from ufl.duals import is_dual, is_primal from pyop2 import op2, mpi from pyop2.utils import as_tuple @@ -296,6 +297,9 @@ def restore_work_function(self, function): cache[function] = False def __eq__(self, other): + if is_primal(self) != is_primal(other) or \ + is_dual(self) != is_dual(other): + return False try: return self.topological == other.topological and \ self.mesh() is other.mesh() diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 4a1ac396c5..488b42d55c 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -9,7 +9,7 @@ PETSc, OptionsManager, flatten_parameters, DEFAULT_KSP_PARAMETERS, DEFAULT_SNES_PARAMETERS ) -from firedrake.function import Function +from firedrake.function import Function, Cofunction from firedrake.ufl_expr import TrialFunction, TestFunction from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin @@ -92,7 +92,11 @@ def __init__(self, F, u, bcs=None, J=None, self.u_restrict = Function(V_res).interpolate(u) v_res, u_res = TestFunction(V_res), TrialFunction(V_res) F_arg, = F.arguments() - self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict}) + replace_F = {F_arg: v_res, self.u: self.u_restrict} + for c in F.coefficients(): + if c.function_space() == V.dual(): + replace_F[c] = Cofunction(V_res.dual()).interpolate(c) + self.F = replace(F, replace_F) v_arg, u_arg = self.J.arguments() self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict}) if self.Jp: diff --git a/tests/firedrake/regression/test_restricted_function_space.py b/tests/firedrake/regression/test_restricted_function_space.py index dc9a2ecc64..6dd093f58a 100644 --- a/tests/firedrake/regression/test_restricted_function_space.py +++ b/tests/firedrake/regression/test_restricted_function_space.py @@ -146,7 +146,8 @@ def test_poisson_inhomogeneous_bcs_2(j): @pytest.mark.parallel(nprocs=3) -def test_poisson_inhomogeneous_bcs_high_level_interface(): +@pytest.mark.parametrize("assembled_rhs", [False, True], ids=("Form", "Cofunction")) +def test_poisson_inhomogeneous_bcs_high_level_interface(assembled_rhs): mesh = UnitSquareMesh(8, 8) V = FunctionSpace(mesh, "CG", 2) bc1 = DirichletBC(V, 0., 1) @@ -155,9 +156,11 @@ def test_poisson_inhomogeneous_bcs_high_level_interface(): v = TestFunction(V) a = inner(grad(u), grad(v)) * dx u = Function(V) - L = inner(Constant(0), v) * dx + L = inner(Constant(-2), v) * dx + if assembled_rhs: + L = assemble(L) solve(a == L, u, bcs=[bc1, bc2], restrict=True) - assert errornorm(SpatialCoordinate(mesh)[0], u) < 1.e-12 + assert errornorm(SpatialCoordinate(mesh)[0]**2, u) < 1.e-12 @pytest.mark.parametrize("j", [1, 2, 5]) From aef788674fd7d405127eff7245b61ec67bd29bfe Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 11 Dec 2024 21:16:30 +0000 Subject: [PATCH 02/31] Fix BCs on Cofunction --- demos/netgen/netgen_mesh.py.rst | 3 +-- .../adjoint_utils/blocks/dirichlet_bc.py | 6 +++--- firedrake/adjoint_utils/blocks/solving.py | 19 +++++++------------ tests/firedrake/regression/test_netgen.py | 6 ++---- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/demos/netgen/netgen_mesh.py.rst b/demos/netgen/netgen_mesh.py.rst index 47fe769e70..3c1bd7301b 100755 --- a/demos/netgen/netgen_mesh.py.rst +++ b/demos/netgen/netgen_mesh.py.rst @@ -380,8 +380,7 @@ We will now show how to solve the Poisson problem on a high-order mesh, of order bc = DirichletBC(V, 0.0, [1]) A = assemble(a, bcs=bc) - b = assemble(l) - bc.apply(b) + b = assemble(l, bcs=bc, zero_bc_nodes=True) solve(A, sol, b, solver_parameters={"ksp_type": "cg", "pc_type": "lu"}) VTKFile("output/Sphere.pvd").write(sol) diff --git a/firedrake/adjoint_utils/blocks/dirichlet_bc.py b/firedrake/adjoint_utils/blocks/dirichlet_bc.py index b06d367da1..a918fb8a9e 100644 --- a/firedrake/adjoint_utils/blocks/dirichlet_bc.py +++ b/firedrake/adjoint_utils/blocks/dirichlet_bc.py @@ -51,7 +51,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, adj_output = None for adj_input in adj_inputs: if isconstant(c): - adj_value = firedrake.Function(self.parent_space.dual()) + adj_value = firedrake.Function(self.parent_space) adj_input.apply(adj_value) if self.function_space != self.parent_space: vec = extract_bc_subvector( @@ -88,11 +88,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, # you can even use the Function outside its domain. # For now we will just assume the FunctionSpace is the same for # the BC and the Function. - adj_value = firedrake.Function(self.parent_space.dual()) + adj_value = firedrake.Function(self.parent_space) adj_input.apply(adj_value) r = extract_bc_subvector( adj_value, c.function_space(), bc - ) + ).riesz_representation("l2") if adj_output is None: adj_output = r else: diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index e4664665b0..0e6adce8df 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -197,14 +197,12 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs): def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): dJdu_copy = dJdu.copy() - kwargs = self.assemble_kwargs.copy() # Homogenize and apply boundary conditions on adj_dFdu and dJdu. bcs = self._homogenize_bcs() - kwargs["bcs"] = bcs - dFdu = self._assemble_dFdu_adj(dFdu_adj_form, **kwargs) + dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs) for bc in bcs: - bc.apply(dJdu) + bc.zero(dJdu) adj_sol = firedrake.Function(self.function_space) firedrake.solve( @@ -219,10 +217,8 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): return adj_sol, adj_sol_bdy def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu): - adj_sol_bdy = firedrake.Function( - self.function_space.dual(), dJdu.dat - firedrake.assemble( - firedrake.action(dFdu_adj_form, adj_sol)).dat) - return adj_sol_bdy + adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol)) + return adj_sol_bdy.riesz_representation("l2") def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): @@ -654,9 +650,8 @@ def _forward_solve(self, lhs, rhs, func, bcs, **kwargs): def _adjoint_solve(self, dJdu, compute_bdy): dJdu_copy = dJdu.copy() # Homogenize and apply boundary conditions on adj_dFdu and dJdu. - bcs = self._homogenize_bcs() - for bc in bcs: - bc.apply(dJdu) + for bc in self.bcs: + bc.zero(dJdu) if ( self._ad_solvers["forward_nlvs"]._problem._constant_jacobian @@ -876,7 +871,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs): self.add_dependency(bc, no_duplicates=True) def apply_mixedmass(self, a): - b = firedrake.Function(self.target_space) + b = firedrake.Function(self.target_space.dual()) with a.dat.vec_ro as vsrc, b.dat.vec_wo as vrhs: self.mixed_mass.mult(vsrc, vrhs) return b diff --git a/tests/firedrake/regression/test_netgen.py b/tests/firedrake/regression/test_netgen.py index 904e9a6819..9f401f1675 100644 --- a/tests/firedrake/regression/test_netgen.py +++ b/tests/firedrake/regression/test_netgen.py @@ -51,8 +51,7 @@ def poisson(h, degree=2): # Assembling matrix A = assemble(a, bcs=bc) - b = assemble(l) - bc.apply(b) + b = assemble(l, bcs=bc, zero_bc_nodes=True) # Solving the problem solve(A, u, b, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"}) @@ -95,8 +94,7 @@ def poisson3D(h, degree=2): # Assembling matrix A = assemble(a, bcs=bc) - b = assemble(l) - bc.apply(b) + b = assemble(l, bcs=bc, zero_bc_nodes=True) # Solving the problem solve(A, u, b, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"}) From 8e5603eb0f068c79b404241fdb38a4cf0dde428f Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 11 Dec 2024 23:52:40 +0000 Subject: [PATCH 03/31] LinearSolver: check function spaces --- firedrake/linear_solver.py | 5 +++++ tests/firedrake/regression/test_assemble_baseform.py | 1 + 2 files changed, 6 insertions(+) diff --git a/firedrake/linear_solver.py b/firedrake/linear_solver.py index c1dfbcc07e..7a9f7d807d 100644 --- a/firedrake/linear_solver.py +++ b/firedrake/linear_solver.py @@ -147,6 +147,11 @@ def solve(self, x, b): if not isinstance(b, (function.Function, cofunction.Cofunction)): raise TypeError("Provided RHS is a '%s', not a Function or Cofunction" % type(b).__name__) + if x.function_space() != self.trial_space or b.function_space() != self.test_space.dual(): + # When solving `Ax = b`, with A: V x U -> R, or equivalently A: V -> U*, + # we need to make sure that x and b belong to V and U*, respectively. + raise ValueError("Mismatching function spaces.") + if len(self.trial_space) > 1 and self.nullspace is not None: self.nullspace._apply(self.trial_space.dof_dset.field_ises) if len(self.test_space) > 1 and self.transpose_nullspace is not None: diff --git a/tests/firedrake/regression/test_assemble_baseform.py b/tests/firedrake/regression/test_assemble_baseform.py index 063a33bdd9..219549faa3 100644 --- a/tests/firedrake/regression/test_assemble_baseform.py +++ b/tests/firedrake/regression/test_assemble_baseform.py @@ -155,6 +155,7 @@ def test_zero_form(M, f, one): assert abs(zero_form - 0.5 * np.prod(f.ufl_shape)) < 1.0e-12 +@pytest.mark.xfail(reason="action(M, M) causes primal-dual error") def test_preprocess_form(M, a, f): from ufl.algorithms import expand_indices, expand_derivatives From 2dd4f762c5f06cb3a3d17ac44ce5d8731dfea154 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 12 Dec 2024 12:25:52 +0000 Subject: [PATCH 04/31] assemble(form, zero_bc_nodes=True) as default --- demos/netgen/netgen_mesh.py.rst | 2 +- firedrake/assemble.py | 18 ++++++++---------- firedrake/matrix_free/operators.py | 2 +- firedrake/slate/static_condensation/scpc.py | 2 +- firedrake/solving_utils.py | 2 +- tests/firedrake/multigrid/test_poisson_gmg.py | 7 +++---- tests/firedrake/regression/test_assemble.py | 2 +- .../regression/test_assemble_baseform.py | 2 +- tests/firedrake/regression/test_netgen.py | 4 ++-- 9 files changed, 19 insertions(+), 22 deletions(-) diff --git a/demos/netgen/netgen_mesh.py.rst b/demos/netgen/netgen_mesh.py.rst index 3c1bd7301b..3947cc828b 100755 --- a/demos/netgen/netgen_mesh.py.rst +++ b/demos/netgen/netgen_mesh.py.rst @@ -380,7 +380,7 @@ We will now show how to solve the Poisson problem on a high-order mesh, of order bc = DirichletBC(V, 0.0, [1]) A = assemble(a, bcs=bc) - b = assemble(l, bcs=bc, zero_bc_nodes=True) + b = assemble(l, bcs=bc) solve(A, sol, b, solver_parameters={"ksp_type": "cg", "pc_type": "lu"}) VTKFile("output/Sphere.pvd").write(sol) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 6b1bb7d3e4..db38430eb3 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -83,7 +83,7 @@ def assemble(expr, *args, **kwargs): zero_bc_nodes : bool If `True`, set the boundary condition nodes in the output tensor to zero rather than to the values prescribed by the - boundary condition. Default is `False`. + boundary condition. Default is `True`. diagonal : bool If assembling a matrix is it diagonal? weight : float @@ -156,7 +156,7 @@ def get_assembler(form, *args, **kwargs): return ZeroFormAssembler(form, form_compiler_parameters=fc_params) elif len(form.arguments()) == 1 or diagonal: return OneFormAssembler(form, *args, bcs=bcs, form_compiler_parameters=fc_params, needs_zeroing=kwargs.get('needs_zeroing', True), - zero_bc_nodes=kwargs.get('zero_bc_nodes', False), diagonal=diagonal) + zero_bc_nodes=kwargs.get('zero_bc_nodes', True), diagonal=diagonal) elif len(form.arguments()) == 2: return TwoFormAssembler(form, *args, **kwargs) else: @@ -308,7 +308,7 @@ def __init__(self, sub_mat_type=None, options_prefix=None, appctx=None, - zero_bc_nodes=False, + zero_bc_nodes=True, diagonal=False, weight=1.0, allocation_integral_types=None): @@ -1190,13 +1190,11 @@ def _apply_bc(self, tensor, bc): raise AssertionError def _apply_dirichlet_bc(self, tensor, bc): - if not self._zero_bc_nodes: - tensor_func = tensor.riesz_representation(riesz_map="l2") - if self._diagonal: - bc.set(tensor_func, 1) - else: - bc.apply(tensor_func) - tensor.assign(tensor_func.riesz_representation(riesz_map="l2")) + if self._diagonal: + bc.set(tensor, self._weight) + elif not self._zero_bc_nodes: + # We cannot set primal data on a dual Cofunction, this will throw an error + bc.apply(tensor) else: bc.zero(tensor) diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index 3ee448730e..e31625354a 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -147,7 +147,7 @@ def __init__(self, a, row_bcs=[], col_bcs=[], self._assemble_action = get_assembler(self.action, bcs=self.bcs_action, form_compiler_parameters=self.fc_params, - zero_bc_nodes=True).assemble + ).assemble # For assembling action(adjoint(f), self._y) # Sorted list of equation bcs diff --git a/firedrake/slate/static_condensation/scpc.py b/firedrake/slate/static_condensation/scpc.py index 35fa4742eb..c4bda6dc27 100644 --- a/firedrake/slate/static_condensation/scpc.py +++ b/firedrake/slate/static_condensation/scpc.py @@ -102,7 +102,7 @@ def initialize(self, pc): r_expr = reduced_sys.rhs # Construct the condensed right-hand side - self._assemble_Srhs = get_assembler(r_expr, bcs=bcs, zero_bc_nodes=True, form_compiler_parameters=self.cxt.fc_params).assemble + self._assemble_Srhs = get_assembler(r_expr, bcs=bcs, form_compiler_parameters=self.cxt.fc_params).assemble # Allocate and set the condensed operator form_assembler = get_assembler(S_expr, bcs=bcs, form_compiler_parameters=self.cxt.fc_params, mat_type=mat_type, options_prefix=prefix, appctx=self.get_appctx(pc)) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index 9e843016b5..1da0c0f836 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -221,7 +221,7 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None, self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F, form_compiler_parameters=self.fcp, - zero_bc_nodes=True).assemble + ).assemble self._jacobian_assembled = False self._splits = {} diff --git a/tests/firedrake/multigrid/test_poisson_gmg.py b/tests/firedrake/multigrid/test_poisson_gmg.py index 81f56acbfc..577587a393 100644 --- a/tests/firedrake/multigrid/test_poisson_gmg.py +++ b/tests/firedrake/multigrid/test_poisson_gmg.py @@ -195,12 +195,11 @@ def test_baseform_coarsening(solver_type, mixed): a_terms.append(inner(grad(u), grad(v)) * dx) a = sum(a_terms) - assemble_bcs = lambda L: assemble(L, bcs=bcs, zero_bc_nodes=True) # These are equivalent right-hand sides sources = [sum(forms), # purely symbolic linear form - assemble_bcs(sum(forms)), # purely numerical cofunction - sum(assemble_bcs(form) for form in forms), # symbolic combination of numerical cofunctions - forms[0] + assemble_bcs(sum(forms[1:])), # symbolic plus numerical + assemble(sum(forms), bcs=bcs), # purely numerical cofunction + sum(assemble(form, bcs=bcs) for form in forms), # symbolic combination of numerical cofunctions + forms[0] + assemble(sum(forms[1:]), bcs=bcs), # symbolic plus numerical ] solutions = [] for L in sources: diff --git a/tests/firedrake/regression/test_assemble.py b/tests/firedrake/regression/test_assemble.py index bd8f020e60..72394380cb 100644 --- a/tests/firedrake/regression/test_assemble.py +++ b/tests/firedrake/regression/test_assemble.py @@ -225,7 +225,7 @@ def test_one_form_assembler_cache(mesh): assert len(L._cache[_FORM_CACHE_KEY]) == 3 # changing zero_bc_nodes should increase the cache size - assemble(L, zero_bc_nodes=True) + assemble(L, zero_bc_nodes=False) assert len(L._cache[_FORM_CACHE_KEY]) == 4 diff --git a/tests/firedrake/regression/test_assemble_baseform.py b/tests/firedrake/regression/test_assemble_baseform.py index 219549faa3..c6c528aaaa 100644 --- a/tests/firedrake/regression/test_assemble_baseform.py +++ b/tests/firedrake/regression/test_assemble_baseform.py @@ -155,7 +155,7 @@ def test_zero_form(M, f, one): assert abs(zero_form - 0.5 * np.prod(f.ufl_shape)) < 1.0e-12 -@pytest.mark.xfail(reason="action(M, M) causes primal-dual error") +@pytest.mark.xfail(reason="action(M, M) raises primal-dual TypeError") def test_preprocess_form(M, a, f): from ufl.algorithms import expand_indices, expand_derivatives diff --git a/tests/firedrake/regression/test_netgen.py b/tests/firedrake/regression/test_netgen.py index 9f401f1675..8cc164d98c 100644 --- a/tests/firedrake/regression/test_netgen.py +++ b/tests/firedrake/regression/test_netgen.py @@ -51,7 +51,7 @@ def poisson(h, degree=2): # Assembling matrix A = assemble(a, bcs=bc) - b = assemble(l, bcs=bc, zero_bc_nodes=True) + b = assemble(l, bcs=bc) # Solving the problem solve(A, u, b, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"}) @@ -94,7 +94,7 @@ def poisson3D(h, degree=2): # Assembling matrix A = assemble(a, bcs=bc) - b = assemble(l, bcs=bc, zero_bc_nodes=True) + b = assemble(l, bcs=bc) # Solving the problem solve(A, u, b, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"}) From 3c5e64fb2f400870e22694248fe2b997a7398123 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 12 Dec 2024 18:35:48 +0000 Subject: [PATCH 05/31] Fix FunctionAssignBlock --- firedrake/adjoint_utils/blocks/function.py | 1 + 1 file changed, 1 insertion(+) diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index e31a0c4567..aa035bcf73 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -79,6 +79,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, ) diff_expr_assembled = firedrake.Function(adj_input_func.function_space()) diff_expr_assembled.interpolate(ufl.conj(diff_expr)) + diff_expr_assembled = diff_expr_assembled.riesz_representation(riesz_map="l2") adj_output = firedrake.Function( R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func)) ) From e3449f5ccdf90f380711405bd8f5a0b44f323a2a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 12 Dec 2024 21:00:05 +0000 Subject: [PATCH 06/31] Allow Cofunction.assign take in constants --- firedrake/assemble.py | 9 +++++---- firedrake/cofunction.py | 8 ++++++-- tests/firedrake/regression/test_bcs.py | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index db38430eb3..59ff6b750e 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -406,7 +406,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params) elif rank == 1 or (rank == 2 and self._diagonal): assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal) + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight) elif rank == 2: assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type, @@ -1149,14 +1149,15 @@ class OneFormAssembler(ParloopFormAssembler): @classmethod def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True, - zero_bc_nodes=False, diagonal=False): + zero_bc_nodes=False, diagonal=False, weight=1.0): bcs = solving._extract_bcs(bcs) return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal @FormAssembler._skip_if_initialised def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True, - zero_bc_nodes=False, diagonal=False): + zero_bc_nodes=False, diagonal=False, weight=1.0): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing) + self._weight = weight self._diagonal = diagonal self._zero_bc_nodes = zero_bc_nodes if self._diagonal and any(isinstance(bc, EquationBCSplit) for bc in self._bcs): @@ -1185,7 +1186,7 @@ def _apply_bc(self, tensor, bc): elif isinstance(bc, EquationBCSplit): bc.zero(tensor) type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index 4878e6da59..d4b16c0728 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -225,8 +225,12 @@ def assign(self, expr, subset=None, expr_from_assemble=False): return self.assign( assembled_expr, subset=subset, expr_from_assemble=True) - - raise ValueError('Cannot assign %s' % expr) + elif expr == 0: + self.dat.zero(subset=subset) + else: + from firedrake.assign import Assigner + Assigner(self, expr, subset).assign() + return self def riesz_representation(self, riesz_map='L2', **solver_options): """Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map. diff --git a/tests/firedrake/regression/test_bcs.py b/tests/firedrake/regression/test_bcs.py index 0e27515890..16cbc669c6 100644 --- a/tests/firedrake/regression/test_bcs.py +++ b/tests/firedrake/regression/test_bcs.py @@ -327,7 +327,7 @@ def test_bcs_rhs_assemble(a, V): b1 = assemble(a) b1_func = b1.riesz_representation(riesz_map="l2") for bc in bcs: - bc.apply(b1_func) + bc.zero(b1_func) b1.assign(b1_func.riesz_representation(riesz_map="l2")) b2 = assemble(a, bcs=bcs) assert np.allclose(b1.dat.data, b2.dat.data) From 40cf6d706653f015824eb5c610d6039dbe889bbc Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 13 Dec 2024 16:59:49 +0000 Subject: [PATCH 07/31] suggestion from code review --- firedrake/assemble.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 59ff6b750e..347f2aca90 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -143,20 +143,18 @@ def get_assembler(form, *args, **kwargs): """ is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False) - bcs = kwargs.get('bcs', None) - fc_params = kwargs.get('form_compiler_parameters', None) if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed: mat_type = kwargs.get('mat_type', None) + fc_params = kwargs.get('form_compiler_parameters', None) # Preprocess the DAG and restructure the DAG # Only pre-process `form` once beforehand to avoid pre-processing for each assembly call form = BaseFormAssembler.preprocess_base_form(form, mat_type=mat_type, form_compiler_parameters=fc_params) if isinstance(form, (ufl.form.Form, slate.TensorBase)) and not BaseFormAssembler.base_form_operands(form): diagonal = kwargs.pop('diagonal', False) if len(form.arguments()) == 0: - return ZeroFormAssembler(form, form_compiler_parameters=fc_params) + return ZeroFormAssembler(form, **kwargs) elif len(form.arguments()) == 1 or diagonal: - return OneFormAssembler(form, *args, bcs=bcs, form_compiler_parameters=fc_params, needs_zeroing=kwargs.get('needs_zeroing', True), - zero_bc_nodes=kwargs.get('zero_bc_nodes', True), diagonal=diagonal) + return OneFormAssembler(form, *args, diagonal=diagonal, **kwargs) elif len(form.arguments()) == 2: return TwoFormAssembler(form, *args, **kwargs) else: @@ -1149,13 +1147,13 @@ class OneFormAssembler(ParloopFormAssembler): @classmethod def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True, - zero_bc_nodes=False, diagonal=False, weight=1.0): + zero_bc_nodes=True, diagonal=False, weight=1.0): bcs = solving._extract_bcs(bcs) - return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal + return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal, weight @FormAssembler._skip_if_initialised def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True, - zero_bc_nodes=False, diagonal=False, weight=1.0): + zero_bc_nodes=True, diagonal=False, weight=1.0): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing) self._weight = weight self._diagonal = diagonal From fe30b4820090b7f22527d1c9f88864ab1dad2547 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 19 Dec 2024 11:17:09 -0600 Subject: [PATCH 08/31] more suggestions from review --- firedrake/assemble.py | 2 +- firedrake/linear_solver.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 347f2aca90..875d27862e 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1192,7 +1192,7 @@ def _apply_dirichlet_bc(self, tensor, bc): if self._diagonal: bc.set(tensor, self._weight) elif not self._zero_bc_nodes: - # We cannot set primal data on a dual Cofunction, this will throw an error + # NOTE this will only work if tensor is a Function and not a Cofunction bc.apply(tensor) else: bc.zero(tensor) diff --git a/firedrake/linear_solver.py b/firedrake/linear_solver.py index 7a9f7d807d..7721675a57 100644 --- a/firedrake/linear_solver.py +++ b/firedrake/linear_solver.py @@ -147,10 +147,12 @@ def solve(self, x, b): if not isinstance(b, (function.Function, cofunction.Cofunction)): raise TypeError("Provided RHS is a '%s', not a Function or Cofunction" % type(b).__name__) - if x.function_space() != self.trial_space or b.function_space() != self.test_space.dual(): - # When solving `Ax = b`, with A: V x U -> R, or equivalently A: V -> U*, - # we need to make sure that x and b belong to V and U*, respectively. - raise ValueError("Mismatching function spaces.") + # When solving `Ax = b`, with A: V x U -> R, or equivalently A: V -> U*, + # we need to make sure that x and b belong to V and U*, respectively. + if x.function_space() != self.trial_space: + raise ValueError(f"x must be a Function in {self.trial_space}.") + if b.function_space() != self.test_space.dual(): + raise ValueError(f"b must be a Cofunction in {self.test_space.dual()}.") if len(self.trial_space) > 1 and self.nullspace is not None: self.nullspace._apply(self.trial_space.dof_dset.field_ises) From 3d49f31ec5967838dd1215c88d0ebedaf9f29c6a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 19 Dec 2024 14:51:26 -0600 Subject: [PATCH 09/31] remove BaseFormAssembler test --- firedrake/cofunction.py | 2 - .../regression/test_assemble_baseform.py | 48 ++----------------- 2 files changed, 3 insertions(+), 47 deletions(-) diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index d4b16c0728..436d0b7556 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -225,8 +225,6 @@ def assign(self, expr, subset=None, expr_from_assemble=False): return self.assign( assembled_expr, subset=subset, expr_from_assemble=True) - elif expr == 0: - self.dat.zero(subset=subset) else: from firedrake.assign import Assigner Assigner(self, expr, subset).assign() diff --git a/tests/firedrake/regression/test_assemble_baseform.py b/tests/firedrake/regression/test_assemble_baseform.py index c6c528aaaa..54cc0b183c 100644 --- a/tests/firedrake/regression/test_assemble_baseform.py +++ b/tests/firedrake/regression/test_assemble_baseform.py @@ -1,7 +1,7 @@ import pytest import numpy as np from firedrake import * -from firedrake.assemble import BaseFormAssembler, get_assembler +from firedrake.assemble import get_assembler from firedrake.utils import ScalarType import ufl @@ -43,39 +43,15 @@ def fs(request, mesh): @pytest.fixture def f(fs): f = Function(fs, name="f") - f_split = f.subfunctions x = SpatialCoordinate(fs.mesh())[0] - - # NOTE: interpolation of UFL expressions into mixed - # function spaces is not yet implemented - for fi in f_split: - fs_i = fi.function_space() - if fs_i.rank == 1: - fi.interpolate(as_vector((x,) * fs_i.value_size)) - elif fs_i.rank == 2: - fi.interpolate(as_tensor([[x for i in range(fs_i.mesh().geometric_dimension())] - for j in range(fs_i.rank)])) - else: - fi.interpolate(x) + f.interpolate(as_tensor(np.full(f.ufl_shape, x))) return f @pytest.fixture def one(fs): one = Function(fs, name="one") - ones = one.subfunctions - - # NOTE: interpolation of UFL expressions into mixed - # function spaces is not yet implemented - for fi in ones: - fs_i = fi.function_space() - if fs_i.rank == 1: - fi.interpolate(Constant((1.0,) * fs_i.value_size)) - elif fs_i.rank == 2: - fi.interpolate(Constant([[1.0 for i in range(fs_i.mesh().geometric_dimension())] - for j in range(fs_i.rank)])) - else: - fi.interpolate(Constant(1.0)) + one.interpolate(Constant(np.ones(one.ufl_shape))) return one @@ -155,24 +131,6 @@ def test_zero_form(M, f, one): assert abs(zero_form - 0.5 * np.prod(f.ufl_shape)) < 1.0e-12 -@pytest.mark.xfail(reason="action(M, M) raises primal-dual TypeError") -def test_preprocess_form(M, a, f): - from ufl.algorithms import expand_indices, expand_derivatives - - expr = action(action(M, M), f) - A = BaseFormAssembler.preprocess_base_form(expr) - B = action(expand_derivatives(M), action(M, f)) - - assert isinstance(A, ufl.Action) - try: - # Need to expand indices to be able to match equal (different MultiIndex used for both). - assert expand_indices(A.left()) == expand_indices(B.left()) - assert expand_indices(A.right()) == expand_indices(B.right()) - except KeyError: - # Index expansion doesn't seem to play well with tensor elements. - pass - - def test_tensor_copy(a, M): # 1-form tensor From df04f4b90d962f0047f7b4d88f5e512833cec682 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 19 Dec 2024 19:16:03 -0600 Subject: [PATCH 10/31] only supply relevant kwargs to OneFormAssembler --- firedrake/assemble.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 875d27862e..afd7076114 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -143,18 +143,24 @@ def get_assembler(form, *args, **kwargs): """ is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False) + fc_params = kwargs.get('form_compiler_parameters', None) if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed: mat_type = kwargs.get('mat_type', None) - fc_params = kwargs.get('form_compiler_parameters', None) # Preprocess the DAG and restructure the DAG # Only pre-process `form` once beforehand to avoid pre-processing for each assembly call form = BaseFormAssembler.preprocess_base_form(form, mat_type=mat_type, form_compiler_parameters=fc_params) if isinstance(form, (ufl.form.Form, slate.TensorBase)) and not BaseFormAssembler.base_form_operands(form): diagonal = kwargs.pop('diagonal', False) if len(form.arguments()) == 0: - return ZeroFormAssembler(form, **kwargs) + return ZeroFormAssembler(form, form_compiler_parameters=fc_params) elif len(form.arguments()) == 1 or diagonal: - return OneFormAssembler(form, *args, diagonal=diagonal, **kwargs) + return OneFormAssembler(form, *args, + bcs=kwargs.get("bcs", None), + form_compiler_parameters=fc_params, + needs_zeroing=kwargs.get("needs_zeroing", True), + zero_bc_nodes=kwargs.get("zero_bc_nodes", True), + diagonal=diagonal, + weight=kwargs.get("weight", 1.0)) elif len(form.arguments()) == 2: return TwoFormAssembler(form, *args, **kwargs) else: @@ -1192,7 +1198,7 @@ def _apply_dirichlet_bc(self, tensor, bc): if self._diagonal: bc.set(tensor, self._weight) elif not self._zero_bc_nodes: - # NOTE this will only work if tensor is a Function and not a Cofunction + # NOTE this only works if tensor is a Function and not a Cofunction bc.apply(tensor) else: bc.zero(tensor) From 950e42d1130f8454eecebf07744821c9649507fa Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 20 Dec 2024 21:15:56 -0600 Subject: [PATCH 11/31] Only interpolate the residual, not every cofunction in the RHS --- firedrake/assemble.py | 12 +++++++++--- firedrake/variational_solver.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index afd7076114..a1d0009eb1 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -385,6 +385,12 @@ def visitor(e, *operands): visited = {} result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited) + # Apply BCs after assembly + rank = len(self._form.arguments()) + if rank == 1: + for bc in self._bcs: + bc.zero(result) + if tensor: BaseFormAssembler.update_tensor(result, tensor) return tensor @@ -409,7 +415,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if rank == 0: assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params) elif rank == 1 or (rank == 2 and self._diagonal): - assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, + assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params, zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight) elif rank == 2: assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, @@ -811,9 +817,9 @@ def restructure_base_form(expr, visited=None): return ufl.action(expr, ustar) # -- Case (6) -- # - if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()): + if isinstance(expr, ufl.FormSum) and all(not isinstance(c, ufl.form.BaseForm) for c in expr.components()): # Return ufl.Sum - return sum([c for c in expr.components()]) + return sum(w*c for w, c in zip(expr.weights(), expr.components())) return expr @staticmethod diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 488b42d55c..3c8fc8b930 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -9,11 +9,12 @@ PETSc, OptionsManager, flatten_parameters, DEFAULT_KSP_PARAMETERS, DEFAULT_SNES_PARAMETERS ) -from firedrake.function import Function, Cofunction -from firedrake.ufl_expr import TrialFunction, TestFunction +from firedrake.function import Function +from firedrake.ufl_expr import TrialFunction, TestFunction, action from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin -from ufl import replace +from firedrake.__future__ import interpolate +from ufl import replace, Form __all__ = ["LinearVariationalProblem", "LinearVariationalSolver", @@ -91,12 +92,11 @@ def __init__(self, F, u, bcs=None, J=None, bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs] self.u_restrict = Function(V_res).interpolate(u) v_res, u_res = TestFunction(V_res), TrialFunction(V_res) - F_arg, = F.arguments() - replace_F = {F_arg: v_res, self.u: self.u_restrict} - for c in F.coefficients(): - if c.function_space() == V.dual(): - replace_F[c] = Cofunction(V_res.dual()).interpolate(c) - self.F = replace(F, replace_F) + if isinstance(F, Form): + F_arg, = F.arguments() + self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict}) + else: + self.F = action(replace(F, {self.u: self.u_restrict}), interpolate(v_res, V)) v_arg, u_arg = self.J.arguments() self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict}) if self.Jp: From a86f3f56f42847e3bd481af9beb381a19b435b85 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 20 Dec 2024 21:21:27 -0600 Subject: [PATCH 12/31] DROP BEFORE MERGE --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0eb616c24d..51bb17c2f1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,6 +84,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch ufl pbrubeck/fix/formsum-weights \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From 337d0872148a93c31d5f9402b1857f019d21ec7c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 21 Dec 2024 08:49:32 -0600 Subject: [PATCH 13/31] Fix tests --- firedrake/assemble.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index a1d0009eb1..de7baf4f6d 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -817,8 +817,8 @@ def restructure_base_form(expr, visited=None): return ufl.action(expr, ustar) # -- Case (6) -- # - if isinstance(expr, ufl.FormSum) and all(not isinstance(c, ufl.form.BaseForm) for c in expr.components()): - # Return ufl.Sum + if isinstance(expr, ufl.FormSum) and all(ufl.duals.is_dual(a.function_space()) for a in expr.arguments()): + # Return ufl.Sum if we are assembling a FormSum with Coarguments (a primal expression) return sum(w*c for w, c in zip(expr.weights(), expr.components())) return expr From ed34164e29498a991c158159625a43be2d63c895 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 21 Dec 2024 19:05:28 -0600 Subject: [PATCH 14/31] Fix adjoint utils --- firedrake/adjoint_utils/blocks/solving.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 0e6adce8df..2cf6d9fd36 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -260,8 +260,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, return dFdm dFdm = -firedrake.derivative(F_form, c_rep, trial_function) - dFdm = firedrake.adjoint(dFdm) - dFdm = dFdm * adj_sol + if isinstance(dFdm, ufl.Form): + dFdm = firedrake.adjoint(dFdm) + dFdm = firedrake.action(dFdm, adj_sol) + else: + dFdm = dFdm(adj_sol) dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs) return dFdm From 027ad371a2e53e232202c41b7cf88084dad25a91 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 21 Dec 2024 21:28:25 -0600 Subject: [PATCH 15/31] More robust test for (unrestricted) Cofunction RHS --- .../regression/test_solving_interface.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/firedrake/regression/test_solving_interface.py b/tests/firedrake/regression/test_solving_interface.py index f32e6f3214..ed47d9e82e 100644 --- a/tests/firedrake/regression/test_solving_interface.py +++ b/tests/firedrake/regression/test_solving_interface.py @@ -222,22 +222,24 @@ def test_constant_jacobian_lvs(): def test_solve_cofunction_rhs(): - mesh = UnitSquareMesh(10, 10) + mesh = UnitIntervalMesh(10) V = FunctionSpace(mesh, "CG", 1) + x, = SpatialCoordinate(mesh) u = TrialFunction(V) v = TestFunction(V) - a = inner(u, v) * dx - + a = inner(grad(u), grad(v)) * dx L = Cofunction(V.dual()) - L.vector()[:] = 1. + bcs = [DirichletBC(V, x, "on_boundary")] + # Set the wrong BCs on the RHS + for bc in bcs: + bc.set(L, 888) + Lold = L.copy() w = Function(V) - solve(a == L, w) - - Aw = assemble(action(a, w)) - assert isinstance(Aw, Cofunction) - assert np.allclose(Aw.dat.data_ro, L.dat.data_ro) + solve(a == L, w, bcs=bcs) + assert errornorm(x, w) < 1E-10 + assert np.allclose(L.dat.data, Lold.dat.data) def test_solve_empty_form_rhs(): From 22865961201f2a68cab66ddcaea14f0742c5f071 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 29 Dec 2024 22:03:29 -0600 Subject: [PATCH 16/31] DO NOT MERGE --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0eb616c24d..fa437c82e5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,6 +84,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch ufl pbrubeck/simplify-indexed \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From bb04bb00860b9f697df226b1ec1ce35f35e4c2bb Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 31 Dec 2024 20:01:24 -0600 Subject: [PATCH 17/31] Replace empty Jacobians with ZeroBaseForm --- firedrake/adjoint_utils/variational_solver.py | 8 +-- firedrake/assemble.py | 11 ++- firedrake/formmanipulation.py | 67 +++++++++++-------- firedrake/preconditioners/massinv.py | 2 +- firedrake/solving_utils.py | 9 ++- firedrake/tsfc_interface.py | 4 +- .../firedrake/slate/test_assemble_tensors.py | 17 +++-- 7 files changed, 70 insertions(+), 48 deletions(-) diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index c90d2668e0..79eb09096e 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -2,6 +2,7 @@ from functools import wraps from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock +from firedrake.ufl_expr import derivative, adjoint from ufl import replace @@ -11,7 +12,6 @@ def _ad_annotate_init(init): @no_annotations @wraps(init) def wrapper(self, *args, **kwargs): - from firedrake import derivative, adjoint, TrialFunction init(self, *args, **kwargs) self._ad_F = self.F self._ad_u = self.u_restrict @@ -20,10 +20,8 @@ def wrapper(self, *args, **kwargs): try: # Some forms (e.g. SLATE tensors) are not currently # differentiable. - dFdu = derivative(self.F, - self.u_restrict, - TrialFunction(self.u_restrict.function_space())) - self._ad_adj_F = adjoint(dFdu) + dFdu = derivative(self.F, self.u_restrict) + self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) except (TypeError, NotImplementedError): self._ad_adj_F = None self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear} diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f3049ae01c..60c934b6c7 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args): @staticmethod def update_tensor(assembled_base_form, tensor): if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): - assembled_base_form.dat.copy(tensor.dat) + if isinstance(assembled_base_form, ufl.ZeroBaseForm): + tensor.dat.zero() + else: + assembled_base_form.dat.copy(tensor.dat) elif isinstance(tensor, matrix.MatrixBase): - # Uses the PETSc copy method. - assembled_base_form.petscmat.copy(tensor.petscmat) + if isinstance(assembled_base_form, ufl.ZeroBaseForm): + tensor.petscmat.zero() + else: + assembled_base_form.petscmat.copy(tensor.petscmat) else: raise NotImplementedError("Cannot update tensor of type %s" % type(tensor)) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 35a6789107..3179961df8 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -2,13 +2,29 @@ import numpy import collections -from ufl import as_vector +from ufl import as_vector, split, ZeroBaseForm from ufl.classes import Zero, FixedIndex, ListTensor from ufl.algorithms.map_integrands import map_integrand_dags +from ufl.algorithms import expand_derivatives from ufl.corealg.map_dag import MultiFunction, map_expr_dags from firedrake.petsc import PETSc from firedrake.ufl_expr import Argument +from firedrake.functionspace import MixedFunctionSpace, FunctionSpace + + +def subspace(V, indices): + try: + indices = tuple(indices) + except TypeError: + # Only one index provided. + indices = (indices, ) + if len(indices) == 1: + W = V[indices[0]] + W = FunctionSpace(W.mesh(), W.ufl_element()) + else: + W = MixedFunctionSpace([V[i] for i in indices]) + return W class ExtractSubBlock(MultiFunction): @@ -26,9 +42,11 @@ def indexed(self, o, child, multiindex): indices = multiindex.indices() if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices): if len(indices) == 1: - return child.ufl_operands[indices[0]._value] + return child[indices[0]] + elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)): + return child else: - return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices())) + return ListTensor(*(child[i] for i in indices)) return self.expr(o, child, multiindex) index_inliner = IndexInliner() @@ -57,6 +75,11 @@ def split(self, form, argument_indices): assert (idx[0] == 0 for idx in self.blocks.values()) return form f = map_integrand_dags(self, form) + f = expand_derivatives(f) + if f.empty(): + f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), indices), + arg.number(), part=arg.part()) + for arg, indices in zip(form.arguments(), argument_indices))) return f expr = MultiFunction.reuse_if_untouched @@ -85,8 +108,6 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds): @PETSc.Log.EventDecorator() def argument(self, o): - from ufl import split - from firedrake import MixedFunctionSpace, FunctionSpace V = o.function_space() if len(V) == 1: # Not on a mixed space, just return ourselves. @@ -95,36 +116,29 @@ def argument(self, o): if o in self._arg_cache: return self._arg_cache[o] - V_is = V.subfunctions indices = self.blocks[o.number()] try: indices = tuple(indices) - nidx = len(indices) except TypeError: # Only one index provided. indices = (indices, ) - nidx = 1 - if nidx == 1: - W = V_is[indices[0]] - W = FunctionSpace(W.mesh(), W.ufl_element()) - a = (Argument(W, o.number(), part=o.part()), ) - else: - W = MixedFunctionSpace([V_is[i] for i in indices]) - a = split(Argument(W, o.number(), part=o.part())) + W = subspace(V, indices) + a = Argument(W, o.number(), part=o.part()) + a = (a, ) if len(W) == 1 else split(a) + args = [] - for i in range(len(V_is)): + for i in range(len(V)): if i in indices: c = indices.index(i) a_ = a[c] if len(a_.ufl_shape) == 0: - args += [a_] + args.append(a_) else: - args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)] + args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape)) else: - args += [Zero() - for j in numpy.ndindex(V_is[i].value_shape)] + args.extend(Zero() for j in numpy.ndindex(V[i].value_shape)) return self._arg_cache.setdefault(o, as_vector(args)) @@ -168,11 +182,10 @@ def split_form(form, diagonal=False): assert len(shape) == 2 for idx in numpy.ndindex(shape): f = splitter.split(form, idx) - if len(f.integrals()) > 0: - if diagonal: - i, j = idx - if i != j: - continue - idx = (i, ) - forms.append(SplitForm(indices=idx, form=f)) + if diagonal: + i, j = idx + if i != j: + continue + idx = (i, ) + forms.append(SplitForm(indices=idx, form=f)) return tuple(forms) diff --git a/firedrake/preconditioners/massinv.py b/firedrake/preconditioners/massinv.py index 92f286c708..d29c704e8b 100644 --- a/firedrake/preconditioners/massinv.py +++ b/firedrake/preconditioners/massinv.py @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC): context, keyed on ``"mu"``. """ def form(self, pc, test, trial): - _, bcs = super(MassInvPC, self).form(pc, test, trial) + _, bcs = super(MassInvPC, self).form(pc) appctx = self.get_appctx(pc) mu = appctx.get("mu", 1.0) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index 9e843016b5..789a6f1880 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -12,8 +12,8 @@ def _make_reasons(reasons): - return dict([(getattr(reasons, r), r) - for r in dir(reasons) if not r.startswith('_')]) + return {getattr(reasons, r): r + for r in dir(reasons) if not r.startswith('_')} KSPReasons = _make_reasons(PETSc.KSP.ConvergedReason()) @@ -333,7 +333,7 @@ def split(self, fields): # Split it apart to shove in the form. subsplit = split(subu) # Permutation from field indexing to indexing of pieces - field_renumbering = dict([f, i] for i, f in enumerate(field)) + field_renumbering = {f: i for i, f in enumerate(field)} vec = [] for i, u in enumerate(us): if i in field: @@ -344,8 +344,7 @@ def split(self, fields): if u.ufl_shape == (): vec.append(u) else: - for idx in numpy.ndindex(u.ufl_shape): - vec.append(u[idx]) + vec.extend(u[idx] for idx in numpy.ndindex(u.ufl_shape)) # So now we have a new representation for the solution # vector in the old problem. For the fields we're going diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index ba10d79507..1117f54bd4 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -11,7 +11,7 @@ import ufl import finat.ufl -from ufl import Form, conj +from ufl import conj, Form, ZeroBaseForm from .ufl_expr import TestFunction from tsfc import compile_form as original_tsfc_compile_form @@ -203,7 +203,7 @@ def compile_form(form, name, parameters=None, split=True, interface=None, diagon iterable = ([(None, )*nargs, form], ) for idx, f in iterable: f = _real_mangle(f) - if not f.integrals(): + if isinstance(f, ZeroBaseForm) or f.empty(): # If we're assembling the R space component of a mixed argument, # and that component doesn't actually appear in the form then we # have an empty form, which we should not attempt to assemble. diff --git a/tests/firedrake/slate/test_assemble_tensors.py b/tests/firedrake/slate/test_assemble_tensors.py index 5aff159b9b..c35d43e27e 100644 --- a/tests/firedrake/slate/test_assemble_tensors.py +++ b/tests/firedrake/slate/test_assemble_tensors.py @@ -249,9 +249,13 @@ def test_matrix_subblocks(mesh): refs = dict(split_form(A.form)) _A = A.blocks for x, y in indices: - ref = assemble(refs[x, y]).M.values block = _A[x, y] - assert np.allclose(assemble(block).M.values, ref, rtol=1e-14) + ref = refs[x, y] + if isinstance(ref, Form): + assert np.allclose(assemble(block).M.values, + assemble(ref).M.values, rtol=1e-14) + elif isinstance(ref, ZeroBaseForm): + assert block.form == ref # Mixed blocks A0101 = _A[:2, :2] @@ -280,9 +284,12 @@ def test_matrix_subblocks(mesh): (A1212_10, refs[(2, 1)])] # Test assembly of blocks of mixed blocks - for tensor, form in items: - ref = assemble(form).M.values - assert np.allclose(assemble(tensor).M.values, ref, rtol=1e-14) + for block, ref in items: + if isinstance(ref, Form): + assert np.allclose(assemble(block).M.values, + assemble(ref).M.values, rtol=1e-14) + elif isinstance(ref, ZeroBaseForm): + assert block.form == ref def test_diagonal(mass, matrix_mixed_nofacet): From af53302c7c9c3eeb185887e779b6cca0bead02e2 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Jan 2025 10:07:05 -0600 Subject: [PATCH 18/31] Do not split off-diagonal blocks if we only want the diagonal --- firedrake/formmanipulation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 708adfd8e5..114651f793 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -196,14 +196,15 @@ def split_form(form, diagonal=False): args = form.arguments() shape = tuple(len(a.function_space()) for a in args) forms = [] + arity = len(shape) if diagonal: - assert len(shape) == 2 + assert arity == 2 + arity = 1 for idx in numpy.ndindex(shape): - f = splitter.split(form, idx) if diagonal: i, j = idx if i != j: continue - idx = (i, ) - forms.append(SplitForm(indices=idx, form=f)) + f = splitter.split(form, idx) + forms.append(SplitForm(indices=idx[:arity], form=f)) return tuple(forms) From 7f40504b440d8735c414f5f919c083be877e8da7 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Jan 2025 20:29:13 -0600 Subject: [PATCH 19/31] Zero-simplify slate Tensors --- firedrake/slate/slac/tsfc_driver.py | 1 + firedrake/slate/slate.py | 45 ++++++++++++++++++++++------- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/firedrake/slate/slac/tsfc_driver.py b/firedrake/slate/slac/tsfc_driver.py index 0f5fbf96d3..136b2a0084 100644 --- a/firedrake/slate/slac/tsfc_driver.py +++ b/firedrake/slate/slac/tsfc_driver.py @@ -50,6 +50,7 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None): assert tensor.terminal, ( "Only terminal tensors have forms associated with them!" ) + # Sets a default name for the subkernel prefix. mapper = RemoveRestrictions() integrals = map(partial(map_integrand_dags, mapper), diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index 85b7af3635..fd9535c31a 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -32,7 +32,7 @@ from ufl.corealg.multifunction import MultiFunction from ufl.classes import Zero from ufl.domain import join_domains, sort_domains -from ufl.form import Form +from ufl.form import Form, ZeroBaseForm import hashlib from firedrake.formmanipulation import ExtractSubBlock @@ -237,7 +237,7 @@ def coeff_map(self): coeff_map[m].update(c.indices[0]) else: m = self.coefficients().index(c) - split_map = tuple(range(len(c.subfunctions))) if isinstance(c, Function) or isinstance(c, Constant) or isinstance(c, Cofunction) else tuple(range(1)) + split_map = tuple(range(len(c.subfunctions))) if isinstance(c, (Function, Constant, Cofunction)) else (0,) coeff_map[m].update(split_map) return tuple((k, tuple(sorted(v)))for k, v in coeff_map.items()) @@ -382,6 +382,10 @@ def __eq__(self, other): """Determines whether two TensorBase objects are equal using their associated keys. """ + if isinstance(other, (int, float)) and other == 0: + if isinstance(self, Tensor): + return isinstance(self.form, ZeroBaseForm) or self.form.empty() + return False return self._key == other._key def __ne__(self, other): @@ -650,7 +654,7 @@ def __init__(self, tensor, indices): """Constructor for the Block class.""" super(Block, self).__init__() self.operands = (tensor,) - self._blocks = dict(enumerate(indices)) + self._blocks = dict(enumerate(map(as_tuple, indices))) self._indices = indices @cached_property @@ -671,14 +675,12 @@ def _split_arguments(self): nargs = [] for i, arg in enumerate(tensor.arguments()): V = arg.function_space() - V_is = V.subfunctions - idx = as_tuple(self._blocks[i]) + idx = self._blocks[i] if len(idx) == 1: - fidx, = idx - W = V_is[fidx] + W = V[idx[0]] W = FunctionSpace(W.mesh(), W.ufl_element()) else: - W = MixedFunctionSpace([V_is[fidx] for fidx in idx]) + W = MixedFunctionSpace([V[fidx] for fidx in idx]) nargs.append(Argument(W, arg.number(), part=arg.part())) @@ -880,7 +882,7 @@ class Tensor(TensorBase): def __init__(self, form, diagonal=False): """Constructor for the Tensor class.""" - if not isinstance(form, Form): + if not isinstance(form, (Form, ZeroBaseForm)): if isinstance(form, Function): raise TypeError("Use AssembledVector instead of Tensor.") raise TypeError("Only UFL forms are acceptable inputs.") @@ -1103,6 +1105,10 @@ def _output_string(self, prec=None): class Transpose(UnaryOp): """An abstract Slate class representing the transpose of a tensor.""" + def __new__(cls, A): + if A == 0: + return Tensor(ZeroBaseForm(A.form.arguments()[::-1])) + return BinaryOp.__new__(cls) @cached_property def arg_function_spaces(self): @@ -1127,6 +1133,10 @@ def _output_string(self, prec=None): class Negative(UnaryOp): """Abstract Slate class representing the negation of a tensor object.""" + def __new__(cls, A): + if A == 0: + return A + return BinaryOp.__new__(cls) @cached_property def arg_function_spaces(self): @@ -1197,6 +1207,12 @@ class Add(BinaryOp): :arg A: a :class:`~.firedrake.slate.TensorBase` object. :arg B: another :class:`~.firedrake.slate.TensorBase` object. """ + def __new__(cls, A, B): + if A == 0: + return B + elif B == 0: + return A + return BinaryOp.__new__(cls) def __init__(self, A, B): """Constructor for the Add class.""" @@ -1238,6 +1254,10 @@ class Mul(BinaryOp): :arg A: a :class:`~.firedrake.slate.TensorBase` object. :arg B: another :class:`~.firedrake.slate.TensorBase` object. """ + def __new__(cls, A, B): + if A == 0 or B == 0: + return Tensor(ZeroBaseForm(A.arguments()[:-1] + B.arguments()[1:])) + return BinaryOp.__new__(cls) def __init__(self, A, B): """Constructor for the Mul class.""" @@ -1295,7 +1315,7 @@ def __new__(cls, A, B, decomposition=None): raise ValueError("Illegal op on a %s-tensor with a %s-tensor." % (A.shape, B.shape)) - fsA = A.arg_function_spaces[::-1][-1] + fsA = A.arg_function_spaces[0] fsB = B.arg_function_spaces[0] assert space_equivalence(fsA, fsB), ( @@ -1348,6 +1368,11 @@ class DiagonalTensor(UnaryOp): """ diagonal = True + def __new__(cls, A): + if A == 0: + return Tensor(ZeroBaseForm(A.arguments()[:1])) + return BinaryOp.__new__(cls) + def __init__(self, A): """Constructor for the Diagonal class.""" assert A.rank == 2, "The tensor must be rank 2." From d68113f9721a60c7a9d460b9d83338a4a6ebb5fa Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 11:31:17 -0600 Subject: [PATCH 20/31] set bcs directly on diagonal Cofunction --- firedrake/matrix_free/operators.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index e31625354a..e9ecf63a30 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -183,11 +183,9 @@ def _assemble_diagonal(self): def getDiagonal(self, mat, vec): self._assemble_diagonal(tensor=self._diagonal) - diagonal_func = self._diagonal.riesz_representation(riesz_map="l2") for bc in self.bcs: # Operator is identity on boundary nodes - bc.set(diagonal_func, 1) - self._diagonal.assign(diagonal_func.riesz_representation(riesz_map="l2")) + bc.set(self._diagonal, 1) with self._diagonal.dat.vec_ro as v: v.copy(vec) From 3d06fc56e58302c9ff179eb9890986fe9e4f50ee Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 11:26:48 -0600 Subject: [PATCH 21/31] ImplicitMatrixContext: handle empty action --- firedrake/assemble.py | 13 ++++++++----- firedrake/matrix_free/operators.py | 28 +++++++++++++++++----------- firedrake/slate/slate.py | 19 +++++++++++-------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 60c934b6c7..88d00c6db8 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -311,7 +311,8 @@ def __init__(self, zero_bc_nodes=False, diagonal=False, weight=1.0, - allocation_integral_types=None): + allocation_integral_types=None, + needs_zeroing=False): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) self._mat_type = mat_type self._sub_mat_type = sub_mat_type @@ -321,6 +322,7 @@ def __init__(self, self._diagonal = diagonal self._weight = weight self._allocation_integral_types = allocation_integral_types + assert not needs_zeroing def allocate(self): rank = len(self._form.arguments()) @@ -1127,7 +1129,8 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - pass + if not isinstance(tensor, op2.Global): + raise TypeError(f"Expecting a op2.Global, got {tensor!r}.") @staticmethod def _as_pyop2_type(tensor, indices=None): @@ -1143,7 +1146,7 @@ class OneFormAssembler(ParloopFormAssembler): Parameters ---------- - form : ufl.Form or slate.TensorBasehe + form : ufl.Form or slate.TensorBase 1-form. Notes @@ -1189,8 +1192,8 @@ def _apply_bc(self, tensor, bc): self._apply_dirichlet_bc(tensor, bc) elif isinstance(bc, EquationBCSplit): bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index 3ee448730e..b111cbde76 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -10,6 +10,9 @@ from firedrake.bcs import DirichletBC, EquationBCSplit from firedrake.petsc import PETSc from firedrake.utils import cached_property +from firedrake.function import Function +from firedrake.cofunction import Cofunction +from ufl.form import ZeroBaseForm __all__ = ("ImplicitMatrixContext", ) @@ -107,23 +110,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[], # create functions from test and trial space to help # with 1-form assembly - test_space, trial_space = [ - a.arguments()[i].function_space() for i in (0, 1) - ] - from firedrake import function, cofunction + test_space, trial_space = ( + arg.function_space() for arg in a.arguments() + ) # Need a cofunction since y receives the assembled result of Ax - self._ystar = cofunction.Cofunction(test_space.dual()) - self._y = function.Function(test_space) - self._x = function.Function(trial_space) - self._xstar = cofunction.Cofunction(trial_space.dual()) + self._ystar = Cofunction(test_space.dual()) + self._y = Function(test_space) + self._x = Function(trial_space) + self._xstar = Cofunction(trial_space.dual()) # These are temporary storage for holding the BC # values during matvec application. _xbc is for # the action and ._ybc is for transpose. if len(self.bcs) > 0: - self._xbc = cofunction.Cofunction(trial_space.dual()) + self._xbc = Cofunction(trial_space.dual()) if len(self.col_bcs) > 0: - self._ybc = cofunction.Cofunction(test_space.dual()) + self._ybc = Cofunction(test_space.dual()) # Get size information from template vecs on test and trial spaces trial_vec = trial_space.dof_dset.layout_vec @@ -135,6 +137,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[], self.action = action(self.a, self._x) self.actionT = action(self.aT, self._y) + # TODO prevent action from returning empty Forms + if self.action.empty(): + self.action = ZeroBaseForm(self.a.arguments()[:-1]) + if self.actionT.empty(): + self.actionT = ZeroBaseForm(self.aT.arguments()[:-1]) # For assembling action(f, self._x) self.bcs_action = [] @@ -170,7 +177,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[], @cached_property def _diagonal(self): - from firedrake import Cofunction assert self.on_diag return Cofunction(self._x.function_space().dual()) diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index fd9535c31a..1a8c792414 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -21,7 +21,10 @@ from ufl import Constant from ufl.coefficient import BaseCoefficient +from firedrake.formmanipulation import ExtractSubBlock from firedrake.function import Function, Cofunction +from firedrake.functionspace import FunctionSpace, MixedFunctionSpace +from firedrake.ufl_expr import Argument, TestFunction from firedrake.utils import cached_property, unique from itertools import chain, count @@ -35,8 +38,6 @@ from ufl.form import Form, ZeroBaseForm import hashlib -from firedrake.formmanipulation import ExtractSubBlock - from tsfc.ufl_utils import extract_firedrake_constants @@ -293,6 +294,10 @@ def solve(self, B, decomposition=None): """ return Solve(self, B, decomposition=decomposition) + def empty(self): + """Returns whether the form associated with the tensor is empty.""" + return False + @cached_property def blocks(self): """Returns an object containing the blocks of the tensor defined @@ -461,8 +466,6 @@ def arg_function_spaces(self): @cached_property def _argument(self): """Generates a 'test function' associated with this class.""" - from firedrake.ufl_expr import TestFunction - V, = self.arg_function_spaces return TestFunction(V) @@ -543,7 +546,6 @@ def arg_function_spaces(self): @cached_property def _argument(self): """Generates a tuple of 'test function' associated with this class.""" - from firedrake.ufl_expr import TestFunction return tuple(TestFunction(fs) for fs in self.arg_function_spaces) def arguments(self): @@ -668,9 +670,6 @@ def _split_arguments(self): """Splits the function space and stores the component spaces determined by the indices. """ - from firedrake.functionspace import FunctionSpace, MixedFunctionSpace - from firedrake.ufl_expr import Argument - tensor, = self.operands nargs = [] for i, arg in enumerate(tensor.arguments()): @@ -938,6 +937,10 @@ def subdomain_data(self): """ return self.form.subdomain_data() + def empty(self): + """Returns whether the form associated with the tensor is empty.""" + return self.form.empty() + def _output_string(self, prec=None): """Creates a string representation of the tensor.""" return ["S", "V", "M"][self.rank] + "_%d" % self.id From 6078f93243c1655d923c6aa7fc81e2fdeffd0e8d Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 20:27:42 -0600 Subject: [PATCH 22/31] Only extract constants referenced in the kernel --- firedrake/assemble.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 88d00c6db8..dafe2a32ab 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -585,7 +585,7 @@ def update_tensor(assembled_base_form, tensor): assembled_base_form.dat.copy(tensor.dat) elif isinstance(tensor, matrix.MatrixBase): if isinstance(assembled_base_form, ufl.ZeroBaseForm): - tensor.petscmat.zero() + tensor.petscmat.zeroEntries() else: assembled_base_form.petscmat.copy(tensor.petscmat) else: @@ -2135,14 +2135,13 @@ def iter_active_coefficients(form, kinfo): @staticmethod def iter_constants(form, kinfo): - """Yield the form constants""" + """Yield the form constants referenced in ``kinfo``.""" if isinstance(form, slate.TensorBase): - for const in form.constants(): - yield const + all_constants = form.constants() else: all_constants = extract_firedrake_constants(form) - for constant_index in kinfo.constant_numbers: - yield all_constants[constant_index] + for constant_index in kinfo.constant_numbers: + yield all_constants[constant_index] @staticmethod def index_function_spaces(form, indices): From 5894b490cecb51532dab35d13b243fa8b4507f44 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 20:37:40 -0600 Subject: [PATCH 23/31] Adjoint: only skip expand_derivatives if necessary --- firedrake/adjoint_utils/variational_solver.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index 79eb09096e..c191308adc 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -21,7 +21,12 @@ def wrapper(self, *args, **kwargs): # Some forms (e.g. SLATE tensors) are not currently # differentiable. dFdu = derivative(self.F, self.u_restrict) - self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) + try: + self._ad_adj_F = adjoint(dFdu) + except ValueError: + # Try again without expanding derivatives, + # as dFdu might have been simplied to an empty Form + self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) except (TypeError, NotImplementedError): self._ad_adj_F = None self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear} From d99ba50b42874f0e61381c5c6ff679064287616b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 4 Jan 2025 17:48:35 -0600 Subject: [PATCH 24/31] style --- firedrake/formmanipulation.py | 15 ++++++--------- firedrake/slate/slac/tsfc_driver.py | 1 - 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 114651f793..5e92bd8e8a 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -144,11 +144,8 @@ def cofunction(self, o): # Not on a mixed space, just return ourselves. return o - try: - indices, = set(self.blocks.values()) - except ValueError: - raise ValueError("Cofunction found on an off-diagonal block") - + # We only need the test space for Cofunction  + indices = self.blocks[0] if len(indices) == 1: i = indices[0] W = V[i] @@ -196,15 +193,15 @@ def split_form(form, diagonal=False): args = form.arguments() shape = tuple(len(a.function_space()) for a in args) forms = [] - arity = len(shape) + rank = len(shape) if diagonal: - assert arity == 2 - arity = 1 + assert rank == 2 + rank = 1 for idx in numpy.ndindex(shape): if diagonal: i, j = idx if i != j: continue f = splitter.split(form, idx) - forms.append(SplitForm(indices=idx[:arity], form=f)) + forms.append(SplitForm(indices=idx[:rank], form=f)) return tuple(forms) diff --git a/firedrake/slate/slac/tsfc_driver.py b/firedrake/slate/slac/tsfc_driver.py index 136b2a0084..0f5fbf96d3 100644 --- a/firedrake/slate/slac/tsfc_driver.py +++ b/firedrake/slate/slac/tsfc_driver.py @@ -50,7 +50,6 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None): assert tensor.terminal, ( "Only terminal tensors have forms associated with them!" ) - # Sets a default name for the subkernel prefix. mapper = RemoveRestrictions() integrals = map(partial(map_integrand_dags, mapper), From d6bb7dd76ae1fed2a6b4ff6ccb5195516d104869 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 4 Jan 2025 20:54:36 -0600 Subject: [PATCH 25/31] EquationBC: do not reconstruct empty Forms --- firedrake/assemble.py | 11 ++++------- firedrake/bcs.py | 8 ++++---- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index dafe2a32ab..61909e9955 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -311,8 +311,7 @@ def __init__(self, zero_bc_nodes=False, diagonal=False, weight=1.0, - allocation_integral_types=None, - needs_zeroing=False): + allocation_integral_types=None): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) self._mat_type = mat_type self._sub_mat_type = sub_mat_type @@ -322,7 +321,6 @@ def __init__(self, self._diagonal = diagonal self._weight = weight self._allocation_integral_types = allocation_integral_types - assert not needs_zeroing def allocate(self): rank = len(self._form.arguments()) @@ -1129,8 +1127,7 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - if not isinstance(tensor, op2.Global): - raise TypeError(f"Expecting a op2.Global, got {tensor!r}.") + pass @staticmethod def _as_pyop2_type(tensor, indices=None): @@ -1192,8 +1189,8 @@ def _apply_bc(self, tensor, bc): self._apply_dirichlet_bc(tensor, bc) elif isinstance(bc, EquationBCSplit): bc.zero(tensor) - get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/bcs.py b/firedrake/bcs.py index f0d007ede4..7c6821b3e3 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col return rank = len(self.f.arguments()) splitter = ExtractSubBlock() - if rank == 1: - form = splitter.split(self.f, argument_indices=(row_field, )) - elif rank == 2: - form = splitter.split(self.f, argument_indices=(row_field, col_field)) + form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank]) + if form == 0: + # form is empty, do nothing + return if u is not None: form = firedrake.replace(form, {self.u: u}) if action_x is not None: From ed584675ee3c795ae4b6ea6606dadb745515cc03 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 6 Jan 2025 17:33:43 -0600 Subject: [PATCH 26/31] lower degree for EquationBC tests --- .../equation_bcs/test_equation_bcs.py | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/tests/firedrake/equation_bcs/test_equation_bcs.py b/tests/firedrake/equation_bcs/test_equation_bcs.py index 087b07aa36..3929eeaddb 100644 --- a/tests/firedrake/equation_bcs/test_equation_bcs.py +++ b/tests/firedrake/equation_bcs/test_equation_bcs.py @@ -17,15 +17,13 @@ def nonlinear_poisson(solver_parameters, mesh_num, porder): u = Function(V) v = TestFunction(V) - f = Function(V) x, y = SpatialCoordinate(mesh) - f.interpolate(- 8.0 * pi * pi * cos(x * pi * 2) * cos(y * pi * 2)) + f = - 8.0 * pi * pi * cos(x * pi * 2) * cos(y * pi * 2) a = - inner(grad(u), grad(v)) * dx L = inner(f, v) * dx - g = Function(V) - g.interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + g = cos(2 * pi * x) * cos(2 * pi * y) # Equivalent to bc1 = EquationBC(v * (u - g1) * ds(1) == 0, u, 1) e2 = as_vector([0., 1.]) @@ -33,7 +31,7 @@ def nonlinear_poisson(solver_parameters, mesh_num, porder): solve(a - L == 0, u, bcs=[bc1], solver_parameters=solver_parameters) - f.interpolate(cos(x * pi * 2) * cos(y * pi * 2)) + f = cos(x * pi * 2) * cos(y * pi * 2) return sqrt(assemble(inner(u - f, u - f) * dx)) @@ -46,15 +44,13 @@ def linear_poisson(solver_parameters, mesh_num, porder): u = TrialFunction(V) v = TestFunction(V) - f = Function(V) x, y = SpatialCoordinate(mesh) - f.interpolate(- 8.0 * pi * pi * cos(x * pi * 2) * cos(y * pi * 2)) + f = - 8.0 * pi * pi * cos(x * pi * 2) * cos(y * pi * 2) a = - inner(grad(u), grad(v)) * dx L = inner(f, v) * dx - g = Function(V) - g.interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + g = cos(2 * pi * x) * cos(2 * pi * y) u_ = Function(V) @@ -62,7 +58,7 @@ def linear_poisson(solver_parameters, mesh_num, porder): solve(a == L, u_, bcs=[bc1], solver_parameters=solver_parameters) - f.interpolate(cos(x * pi * 2) * cos(y * pi * 2)) + f = cos(x * pi * 2) * cos(y * pi * 2) return sqrt(assemble(inner(u_ - f, u_ - f) * dx)) @@ -75,9 +71,8 @@ def nonlinear_poisson_bbc(solver_parameters, mesh_num, porder): u = Function(V) v = TestFunction(V) - f = Function(V) x, y = SpatialCoordinate(mesh) - f.interpolate(- 8.0 * pi * pi * cos(x * pi * 2)*cos(y * pi * 2)) + f = - 8.0 * pi * pi * cos(x * pi * 2)*cos(y * pi * 2) a = - inner(grad(u), grad(v)) * dx L = inner(f, v) * dx @@ -85,13 +80,13 @@ def nonlinear_poisson_bbc(solver_parameters, mesh_num, porder): e2 = as_vector([0., 1.]) a1 = (-inner(dot(grad(u), e2), dot(grad(v), e2)) + 4 * pi * pi * inner(u, v)) * ds(1) - g = Function(V).interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + g = cos(2 * pi * x) * cos(2 * pi * y) bbc = DirichletBC(V, g, ((1, 3), (1, 4))) bc1 = EquationBC(a1 == 0, u, 1, bcs=[bbc]) solve(a - L == 0, u, bcs=[bc1], solver_parameters=solver_parameters) - f.interpolate(cos(x * pi * 2) * cos(y * pi * 2)) + f = cos(x * pi * 2) * cos(y * pi * 2) return sqrt(assemble(inner(u - f, u - f) * dx)) @@ -104,9 +99,8 @@ def linear_poisson_bbc(solver_parameters, mesh_num, porder): u = TrialFunction(V) v = TestFunction(V) - f = Function(V) x, y = SpatialCoordinate(mesh) - f.interpolate(- 8.0 * pi * pi * cos(x * pi * 2)*cos(y * pi * 2)) + f = - 8.0 * pi * pi * cos(x * pi * 2)*cos(y * pi * 2) a = - inner(grad(u), grad(v)) * dx L = inner(f, v) * dx @@ -117,13 +111,13 @@ def linear_poisson_bbc(solver_parameters, mesh_num, porder): u = Function(V) - g = Function(V).interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + g = cos(2 * pi * x) * cos(2 * pi * y) bbc = DirichletBC(V, g, ((1, 3), (1, 4))) bc1 = EquationBC(a1 == L1, u, 1, bcs=[bbc]) solve(a == L, u, bcs=[bc1], solver_parameters=solver_parameters) - f.interpolate(cos(x * pi * 2)*cos(y * pi * 2)) + f = cos(x * pi * 2)*cos(y * pi * 2) return sqrt(assemble(inner(u - f, u - f) * dx)) @@ -141,22 +135,25 @@ def nonlinear_poisson_mixed(solver_parameters, mesh_num, porder): n = FacetNormal(mesh) x, y = SpatialCoordinate(mesh) - f = Function(DG).interpolate(-8 * pi * pi * cos(2 * pi * x + pi / 3) * cos(2 * pi * y)) - u1 = Function(DG).interpolate(cos(2 * pi * y) / 2) + f = -8 * pi * pi * cos(2 * pi * x + pi / 3) * cos(2 * pi * y) + u1 = cos(2 * pi * y) / 2 - a = (inner(sigma, tau) + inner(u, div(tau)) + inner(div(sigma), v)) * dx + a = inner(sigma, tau) * dx + inner(u, div(tau)) * dx + inner(div(sigma), v) * dx L = inner(u1, dot(tau, n)) * ds(1) + inner(f, v) * dx - g = Function(BDM).project(as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])) + g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) - bc2 = EquationBC(inner((dot(sigma, n) - dot(g, n)), dot(tau, n)) * ds(2) == 0, w, 2, V=W.sub(0)) - bc3 = EquationBC(inner((dot(sigma, n) - dot(g, n)), dot(tau, n)) * ds(3) == 0, w, 3, V=W.sub(0)) + tau_n = dot(tau, n) + sig_n = dot(sigma, n) + g_n = dot(g, n) + bc2 = EquationBC(inner(sig_n - g_n, tau_n) * ds(2) == 0, w, 2, V=W.sub(0)) + bc3 = EquationBC(inner(sig_n - g_n, tau_n) * ds(3) == 0, w, 3, V=W.sub(0)) bc4 = DirichletBC(W.sub(0), g, 4) solve(a - L == 0, w, bcs=[bc2, bc3, bc4], solver_parameters=solver_parameters) - f.interpolate(cos(2 * pi * x + pi / 3) * cos(2 * pi * y)) - g = Function(BDM).project(as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])) + f = cos(2 * pi * x + pi / 3) * cos(2 * pi * y) + g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) return sqrt(assemble(inner(u - f, u - f) * dx)), sqrt(assemble(inner(sigma - g, sigma - g) * dx)) @@ -173,28 +170,31 @@ def linear_poisson_mixed(solver_parameters, mesh_num, porder): tau, v = TestFunctions(W) x, y = SpatialCoordinate(mesh) - f = Function(DG).interpolate(-8 * pi * pi * cos(2 * pi * x + pi / 3) * cos(2 * pi * y)) - u1 = Function(DG).interpolate(cos(2 * pi * y) / 2) + f = -8 * pi * pi * cos(2 * pi * x + pi / 3) * cos(2 * pi * y) + u1 = cos(2 * pi * y) / 2 n = FacetNormal(mesh) a = (inner(sigma, tau) + inner(u, div(tau)) + inner(div(sigma), v)) * dx L = inner(u1, dot(tau, n)) * ds(1) + inner(f, v) * dx - g = Function(BDM).project(as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])) + g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) w = Function(W) - bc2 = EquationBC(inner(n, tau) * inner(sigma, n) * ds(2) == inner(n, tau) * inner(g, n) * ds(2), w, 2, V=W.sub(0)) - bc3 = EquationBC(inner(n, tau) * inner(sigma, n) * ds(3) == inner(n, tau) * inner(g, n) * ds(3), w, 3, V=W.sub(0)) + tau_n = dot(tau, n) + sig_n = dot(sigma, n) + g_n = dot(g, n) + bc2 = EquationBC(inner(sig_n, tau_n) * ds(2) == inner(g_n, tau_n) * ds(2), w, 2, V=W.sub(0)) + bc3 = EquationBC(inner(sig_n, tau_n) * ds(3) == inner(g_n, tau_n) * ds(3), w, 3, V=W.sub(0)) bc4 = DirichletBC(W.sub(0), g, 4) solve(a == L, w, bcs=[bc2, bc3, bc4], solver_parameters=solver_parameters) - sigma, u = w.subfunctions - f.interpolate(cos(2 * pi * x + pi / 3) * cos(2 * pi * y)) - g = Function(BDM).project(as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])) + f = cos(2 * pi * x + pi / 3) * cos(2 * pi * y) + g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) + sigma, u = w.subfunctions return sqrt(assemble(inner(u - f, u - f) * dx)), sqrt(assemble(inner(sigma - g, sigma - g) * dx)) @@ -202,7 +202,7 @@ def linear_poisson_mixed(solver_parameters, mesh_num, porder): @pytest.mark.parametrize("with_bbc", [False, True]) def test_EquationBC_poisson_matrix(eq_type, with_bbc): mat_type = "aij" - porder = 3 + porder = 2 # Test standard poisson with EquationBCs # aij @@ -235,7 +235,7 @@ def test_EquationBC_poisson_matrix(eq_type, with_bbc): def test_EquationBC_poisson_matfree(with_bbc): eq_type = "linear" mat_type = "matfree" - porder = 3 + porder = 2 # Test standard poisson with EquationBCs # matfree @@ -271,7 +271,7 @@ def test_EquationBC_poisson_matfree(with_bbc): @pytest.mark.parametrize("eq_type", ["linear", "nonlinear"]) def test_EquationBC_mixedpoisson_matrix(eq_type): mat_type = "aij" - porder = 2 + porder = 0 # Mixed poisson with EquationBCs # aij @@ -294,7 +294,7 @@ def test_EquationBC_mixedpoisson_matrix(eq_type): def test_EquationBC_mixedpoisson_matrix_fieldsplit(): mat_type = "aij" eq_type = "linear" - porder = 2 + porder = 0 # Mixed poisson with EquationBCs # aij with fieldsplit pc @@ -324,7 +324,7 @@ def test_EquationBC_mixedpoisson_matrix_fieldsplit(): def test_EquationBC_mixedpoisson_matfree_fieldsplit(): mat_type = "matfree" eq_type = "linear" - porder = 2 + porder = 0 # Mixed poisson with EquationBCs # matfree with fieldsplit pc @@ -366,7 +366,7 @@ def test_equation_bcs_pc(): v, w = split(TestFunction(V)) x, y = SpatialCoordinate(mesh) exact = cos(2 * pi * x) * cos(2 * pi * y) - g = Function(CG).interpolate(8 * pi**2 * exact) + g = 8 * pi**2 * exact F = inner(grad(u), grad(v)) * dx + inner(l, w) * dx - inner(g, v) * dx bc = EquationBC(inner((u - exact), v) * ds == 0, f, (1, 2, 3, 4), V=V.sub(0)) params = { From 2a0c03b244ef9f641fe959d026225d49dd007a14 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 6 Jan 2025 17:35:03 -0600 Subject: [PATCH 27/31] style --- firedrake/bcs.py | 2 +- firedrake/formmanipulation.py | 2 +- tests/firedrake/equation_bcs/test_equation_bcs.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 7c6821b3e3..5884907feb 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -635,7 +635,7 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col rank = len(self.f.arguments()) splitter = ExtractSubBlock() form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank]) - if form == 0: + if isinstance(form, ufl.ZeroBaseForm) or form.empty(): # form is empty, do nothing return if u is not None: diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 5e92bd8e8a..97f2b3c43e 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -73,8 +73,8 @@ def split(self, form, argument_indices): assert (len(idx) == 1 for idx in self.blocks.values()) assert (idx[0] == 0 for idx in self.blocks.values()) return form - f = map_integrand_dags(self, form) # TODO find a way to distinguish empty Forms avoiding expand_derivatives + f = map_integrand_dags(self, form) if expand_derivatives(f).empty(): # Get ZeroBaseForm with the right shape f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), diff --git a/tests/firedrake/equation_bcs/test_equation_bcs.py b/tests/firedrake/equation_bcs/test_equation_bcs.py index 3929eeaddb..fdd05b7f2e 100644 --- a/tests/firedrake/equation_bcs/test_equation_bcs.py +++ b/tests/firedrake/equation_bcs/test_equation_bcs.py @@ -190,7 +190,6 @@ def linear_poisson_mixed(solver_parameters, mesh_num, porder): solve(a == L, w, bcs=[bc2, bc3, bc4], solver_parameters=solver_parameters) - f = cos(2 * pi * x + pi / 3) * cos(2 * pi * y) g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) From 934ff6f14b95681fe421e3f3f70868521ffdc39e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 9 Jan 2025 14:40:53 +0000 Subject: [PATCH 28/31] FunctionSpace: multiindex returns subspace --- firedrake/formmanipulation.py | 47 ++++++++++------------------------ firedrake/functionspaceimpl.py | 15 +++++++++-- firedrake/slate/slate.py | 40 +++++++++++++---------------- 3 files changed, 45 insertions(+), 57 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 97f2b3c43e..815ae360a2 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -12,18 +12,7 @@ from pyop2.utils import as_tuple from firedrake.petsc import PETSc -from firedrake.ufl_expr import Argument from firedrake.cofunction import Cofunction -from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace - - -def subspace(V, indices): - if len(indices) == 1: - W = V[indices[0]] - W = FunctionSpace(W.mesh(), W.ufl_element()) - else: - W = MixedFunctionSpace([V[i] for i in indices]) - return W class ExtractSubBlock(MultiFunction): @@ -50,6 +39,10 @@ def indexed(self, o, child, multiindex): index_inliner = IndexInliner() + def _subspace_argument(self, a): + return type(a)(a.function_space()[self.blocks[a.number()]].collapse(), + a.number(), part=a.part()) + @PETSc.Log.EventDecorator() def split(self, form, argument_indices): """Split a form. @@ -77,10 +70,7 @@ def split(self, form, argument_indices): f = map_integrand_dags(self, form) if expand_derivatives(f).empty(): # Get ZeroBaseForm with the right shape - f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), - self.blocks[arg.number()]), - arg.number(), part=arg.part()) - for arg in form.arguments())) + f = ZeroBaseForm(tuple(map(self._subspace_argument, form.arguments()))) return f expr = MultiFunction.reuse_if_untouched @@ -120,19 +110,14 @@ def argument(self, o): indices = self.blocks[o.number()] - W = subspace(V, indices) - a = Argument(W, o.number(), part=o.part()) - a = (a, ) if len(W) == 1 else split(a) + a = self._subspace_argument(o) + asplit = (a, ) if len(indices) == 1 else split(a) args = [] for i in range(len(V)): if i in indices: - c = indices.index(i) - a_ = a[c] - if len(a_.ufl_shape) == 0: - args.append(a_) - else: - args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape)) + asub = asplit[indices.index(i)] + args.extend(asub[j] for j in numpy.ndindex(asub.ufl_shape)) else: args.extend(Zero() for j in numpy.ndindex(V[i].value_shape)) return self._arg_cache.setdefault(o, as_vector(args)) @@ -144,17 +129,13 @@ def cofunction(self, o): # Not on a mixed space, just return ourselves. return o - # We only need the test space for Cofunction  + # We only need the test space for Cofunction indices = self.blocks[0] - if len(indices) == 1: - i = indices[0] - W = V[i] - W = DualSpace(W.mesh(), W.ufl_element()) - c = Cofunction(W, val=o.dat[i]) + W = V[indices].collapse() + if len(W) == 1: + return Cofunction(W, val=o.dat[indices[0]]) else: - W = MixedFunctionSpace([V[i] for i in indices]) - c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices)) - return c + return Cofunction(W, val=MixedDat(o.dat[i] for i in indices)) SplitForm = collections.namedtuple("SplitForm", ["indices", "form"]) diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 8fc81244f7..0fdfbe8e20 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -321,6 +321,14 @@ def __iter__(self): return iter(self.subfunctions) def __getitem__(self, i): + from firedrake.functionspace import MixedFunctionSpace + if isinstance(i, (tuple, list)): + # Return a subspace + if len(i) == 1: + return self[i[0]] + else: + return MixedFunctionSpace([self[isub] for isub in i]) + return self.subfunctions[i] def __mul__(self, other): @@ -944,6 +952,9 @@ def __hash__(self): def local_to_global_map(self, bcs, lgmap=None): return lgmap or self.dof_dset.lgmap + def collapse(self): + return type(self)(self.function_space.collapse(), boundary_set=self.boundary_set) + class MixedFunctionSpace(object): r"""A function space on a mixed finite element. @@ -1236,16 +1247,16 @@ class ProxyRestrictedFunctionSpace(RestrictedFunctionSpace): r"""A :class:`RestrictedFunctionSpace` that one can attach extra properties to. :arg function_space: The function space to be restricted. - :kwarg name: The name of the restricted function space. :kwarg boundary_set: The boundary domains on which boundary conditions will be specified + :kwarg name: The name of the restricted function space. .. warning:: Users should not build a :class:`ProxyRestrictedFunctionSpace` directly, it is mostly used as an internal implementation detail. """ - def __new__(cls, function_space, name=None, boundary_set=frozenset()): + def __new__(cls, function_space, boundary_set=frozenset(), name=None): topology = function_space._mesh.topology self = super(ProxyRestrictedFunctionSpace, cls).__new__(cls) if function_space._mesh is not topology: diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index 1a8c792414..58e281368f 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -23,8 +23,7 @@ from firedrake.formmanipulation import ExtractSubBlock from firedrake.function import Function, Cofunction -from firedrake.functionspace import FunctionSpace, MixedFunctionSpace -from firedrake.ufl_expr import Argument, TestFunction +from firedrake.ufl_expr import TestFunction from firedrake.utils import cached_property, unique from itertools import chain, count @@ -35,7 +34,7 @@ from ufl.corealg.multifunction import MultiFunction from ufl.classes import Zero from ufl.domain import join_domains, sort_domains -from ufl.form import Form, ZeroBaseForm +from ufl.form import BaseForm, Form, ZeroBaseForm import hashlib from tsfc.ufl_utils import extract_firedrake_constants @@ -461,7 +460,11 @@ def arg_function_spaces(self): """Returns a tuple of function spaces that the tensor is defined on. """ - return (self._function.ufl_function_space(),) + tensor = self._function + if isinstance(tensor, BaseForm): + return tuple(a.function_space() for a in tensor.arguments()) + else: + return (tensor.ufl_function_space(),) @cached_property def _argument(self): @@ -671,19 +674,9 @@ def _split_arguments(self): spaces determined by the indices. """ tensor, = self.operands - nargs = [] - for i, arg in enumerate(tensor.arguments()): - V = arg.function_space() - idx = self._blocks[i] - if len(idx) == 1: - W = V[idx[0]] - W = FunctionSpace(W.mesh(), W.ufl_element()) - else: - W = MixedFunctionSpace([V[fidx] for fidx in idx]) - - nargs.append(Argument(W, arg.number(), part=arg.part())) - - return tuple(nargs) + return tuple(type(a)(a.function_space()[self._blocks[i]].collapse(), + a.number(), part=a.part()) + for i, a in enumerate(tensor.arguments())) @cached_property def arg_function_spaces(self): @@ -1110,7 +1103,10 @@ class Transpose(UnaryOp): """An abstract Slate class representing the transpose of a tensor.""" def __new__(cls, A): if A == 0: - return Tensor(ZeroBaseForm(A.form.arguments()[::-1])) + return Tensor(ZeroBaseForm(A.arguments()[::-1])) + if isinstance(A, Transpose): + tensor, = A.operands + return tensor return BinaryOp.__new__(cls) @cached_property @@ -1223,8 +1219,8 @@ def __init__(self, A, B): raise ValueError("Illegal op on a %s-tensor with a %s-tensor." % (A.shape, B.shape)) - assert all([space_equivalence(fsA, fsB) for fsA, fsB in - zip(A.arg_function_spaces, B.arg_function_spaces)]), ( + assert all(space_equivalence(fsA, fsB) for fsA, fsB in + zip(A.arg_function_spaces, B.arg_function_spaces)), ( "Function spaces associated with operands must match." ) @@ -1311,12 +1307,12 @@ class Solve(BinaryOp): def __new__(cls, A, B, decomposition=None): assert A.rank == 2, "Operator must be a matrix." + assert B.rank >= 1, "RHS must be a vector or matrix." # Same rules for performing multiplication on Slate tensors # applies here. if A.shape[1] != B.shape[0]: - raise ValueError("Illegal op on a %s-tensor with a %s-tensor." - % (A.shape, B.shape)) + raise ValueError(f"Illegal op on a {A.shape}-tensor with a {B.shape}-tensor.") fsA = A.arg_function_spaces[0] fsB = B.arg_function_spaces[0] From 8688f5ec41fcf7a5954f2b1ee9075dd806b364b7 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 10 Jan 2025 20:13:27 +0000 Subject: [PATCH 29/31] Revert WithGeometry.__getitem__ --- firedrake/formmanipulation.py | 14 ++++++++++++-- firedrake/functionspaceimpl.py | 8 -------- firedrake/slate/slate.py | 4 ++-- firedrake/solving_utils.py | 4 +--- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 815ae360a2..1d2aa67aae 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -12,9 +12,19 @@ from pyop2.utils import as_tuple from firedrake.petsc import PETSc +from firedrake.functionspace import MixedFunctionSpace from firedrake.cofunction import Cofunction +def subspace(V, indices): + """Construct a collapsed subspace using components from V.""" + if len(indices) == 1: + W = V[indices[0]] + else: + W = MixedFunctionSpace([V[i] for i in indices]) + return W.collapse() + + class ExtractSubBlock(MultiFunction): """Extract a sub-block from a form.""" @@ -40,7 +50,7 @@ def indexed(self, o, child, multiindex): index_inliner = IndexInliner() def _subspace_argument(self, a): - return type(a)(a.function_space()[self.blocks[a.number()]].collapse(), + return type(a)(subspace(a.function_space(), self.blocks[a.number()]), a.number(), part=a.part()) @PETSc.Log.EventDecorator() @@ -131,7 +141,7 @@ def cofunction(self, o): # We only need the test space for Cofunction indices = self.blocks[0] - W = V[indices].collapse() + W = subspace(V, indices) if len(W) == 1: return Cofunction(W, val=o.dat[indices[0]]) else: diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 0fdfbe8e20..65935e5faf 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -321,14 +321,6 @@ def __iter__(self): return iter(self.subfunctions) def __getitem__(self, i): - from firedrake.functionspace import MixedFunctionSpace - if isinstance(i, (tuple, list)): - # Return a subspace - if len(i) == 1: - return self[i[0]] - else: - return MixedFunctionSpace([self[isub] for isub in i]) - return self.subfunctions[i] def __mul__(self, other): diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index 58e281368f..bb6909128f 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -21,7 +21,7 @@ from ufl import Constant from ufl.coefficient import BaseCoefficient -from firedrake.formmanipulation import ExtractSubBlock +from firedrake.formmanipulation import ExtractSubBlock, subspace from firedrake.function import Function, Cofunction from firedrake.ufl_expr import TestFunction from firedrake.utils import cached_property, unique @@ -674,7 +674,7 @@ def _split_arguments(self): spaces determined by the indices. """ tensor, = self.operands - return tuple(type(a)(a.function_space()[self._blocks[i]].collapse(), + return tuple(type(a)(subspace(a.function_space(), self._blocks[i]), a.number(), part=a.part()) for i, a in enumerate(tensor.arguments())) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index 789a6f1880..c6444f0daa 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -332,15 +332,13 @@ def split(self, fields): subu = function.Function(V, val=val) # Split it apart to shove in the form. subsplit = split(subu) - # Permutation from field indexing to indexing of pieces - field_renumbering = {f: i for i, f in enumerate(field)} vec = [] for i, u in enumerate(us): if i in field: # If this is a field we're keeping, get it from # the new function. Otherwise just point to the # old data. - u = subsplit[field_renumbering[i]] + u = subsplit[field.index(i)] if u.ufl_shape == (): vec.append(u) else: From f3c4ef62d3259de57ff413561555e3e5d1a0d8b1 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 15 Jan 2025 19:21:47 +0000 Subject: [PATCH 30/31] Do not zero a ZeroBaseForm --- firedrake/assemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index ec5e011260..d08e64a487 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -387,7 +387,7 @@ def visitor(e, *operands): # Apply BCs after assembly rank = len(self._form.arguments()) - if rank == 1: + if rank == 1 and not isinstance(result, ufl.ZeroBaseForm): for bc in self._bcs: bc.zero(result) From 8f7ca9b0ec494daec1851cb21578b4e3e5b64b9e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 15 Jan 2025 20:42:42 +0000 Subject: [PATCH 31/31] Update .github/workflows/build.yml --- .github/workflows/build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 16bb5f41cd..0eb616c24d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,7 +84,6 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ - --package-branch ufl pbrubeck/merge-upstream \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: |