Skip to content

Commit

Permalink
slate.slate -> slate
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Feb 5, 2025
1 parent 9d34989 commit da40554
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
6 changes: 1 addition & 5 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from firedrake.cofunction import Cofunction
from firedrake.vector import Vector
from firedrake.matrix import MatrixBase
from firedrake.slate import slate
from firedrake.petsc import PETSc
from pyop2.mpi import internal_comm
from firedrake.variational_solver import LinearVariationalProblem, LinearVariationalSolver
Expand Down Expand Up @@ -49,11 +48,8 @@ def __init__(self, A, *, P=None, **kwargs):
test, trial = A.a.arguments()
x = Function(trial.function_space())
b = Cofunction(test.function_space().dual())
L = b
if isinstance(A.a, slate.TensorBase):
L = slate.AssembledVector(b)

problem = LinearVariationalProblem(A, L, x, bcs=A.bcs, aP=P)
problem = LinearVariationalProblem(A, b, x, bcs=A.bcs, aP=P)
solver = LinearVariationalSolver(problem, **kwargs)
self.b = b
self.x = x
Expand Down
18 changes: 12 additions & 6 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from contextlib import ExitStack
from types import MappingProxyType

from firedrake import dmhooks, slate, solving, solving_utils, ufl_expr, utils
from firedrake import dmhooks, solving, solving_utils, ufl_expr, utils
from firedrake.slate import slate
from firedrake.petsc import (
PETSc, OptionsManager, flatten_parameters, DEFAULT_KSP_PARAMETERS,
DEFAULT_SNES_PARAMETERS
Expand All @@ -23,15 +24,15 @@


def check_pde_args(F, J, Jp):
if not isinstance(F, (ufl.BaseForm, slate.slate.TensorBase)):
if not isinstance(F, (ufl.BaseForm, slate.TensorBase)):
raise TypeError("Provided residual is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__)
if len(F.arguments()) != 1:
raise ValueError("Provided residual is not a linear form")
if not isinstance(J, (ufl.BaseForm, slate.slate.TensorBase)):
if not isinstance(J, (ufl.BaseForm, slate.TensorBase)):
raise TypeError("Provided Jacobian is a '%s', not a BaseForm or Slate Tensor" % type(J).__name__)
if len(J.arguments()) != 2:
raise ValueError("Provided Jacobian is not a bilinear form")
if Jp is not None and not isinstance(Jp, (ufl.BaseForm, slate.slate.TensorBase)):
if Jp is not None and not isinstance(Jp, (ufl.BaseForm, slate.TensorBase)):
raise TypeError("Provided preconditioner is a '%s', not a BaseForm or Slate Tensor" % type(Jp).__name__)
if Jp is not None and len(Jp.arguments()) != 2:
raise ValueError("Provided preconditioner is not a bilinear form")
Expand Down Expand Up @@ -367,16 +368,21 @@ def __init__(self, a, L, u, bcs=None, aP=None,
# In the linear case, the Jacobian is the equation LHS.
J = a
# Jacobian is checked in superclass, but let's check L here.
if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)) and L == 0:
if not isinstance(L, (ufl.BaseForm, slate.TensorBase)) and L == 0:
F = ufl_expr.action(J, u)
else:
if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)):
if not isinstance(L, (ufl.BaseForm, slate.TensorBase)):
raise TypeError("Provided RHS is a '%s', not a Form or Slate Tensor" % type(L).__name__)
if len(L.arguments()) != 1 and not L.empty():
raise ValueError("Provided RHS is not a linear form")
A = J
if isinstance(A, MatrixBase):
A = A.a
if isinstance(A, slate.TensorBase) and isinstance(L, ufl.BaseForm):
if isinstance(L, ufl.Form):
L = slate.Tensor(L)
else:
L = slate.AssembledVector(L)
F = ufl_expr.action(A, u) - L

super(LinearVariationalProblem, self).__init__(F, u, bcs, J, aP,
Expand Down

0 comments on commit da40554

Please sign in to comment.