Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saddle point PC for WC-4DVar #4009

Draft
wants to merge 3 commits into
base: allatoncereducedfunctional
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 72 additions & 40 deletions firedrake/adjoint/fourdvar_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from pyadjoint.enlisting import Enlist
from firedrake.function import Function
from firedrake.ensemblefunction import EnsembleFunction, EnsembleCofunction
from firedrake import assemble, inner, dx, Constant
from firedrake.adjoint.composite_reduced_functional import (
CompositeReducedFunctional, tlm, hessian, intermediate_options)

from functools import wraps, cached_property
from typing import Callable, Optional
from functools import wraps, cached_property, partial
from typing import Callable, Optional, Collection, Union
from types import SimpleNamespace
from contextlib import contextmanager
from mpi4py import MPI
Expand Down Expand Up @@ -91,7 +91,7 @@ class FourDVarReducedFunctional(ReducedFunctional):
The :class:`.EnsembleFunction` for the control x_{i} at the initial condition
and at the end of each observation stage.

background_iprod
background_covariance
The inner product to calculate the background error functional
from the background error :math:`x_{0} - x_{b}`. Can include the
error covariance matrix. Only used on ensemble rank 0.
Expand All @@ -101,18 +101,18 @@ class FourDVarReducedFunctional(ReducedFunctional):
If not provided, the value of the first subfunction on the first ensemble
member of the control :class:`.EnsembleFunction` will be used.

observation_err
observation_error
Given a state :math:`x`, returns the observations error
:math:`y_{0} - \\mathcal{H}_{0}(x)` where :math:`y_{0}` are the
observations at the initial time and :math:`\\mathcal{H}_{0}` is
the observation operator for the initial time. Only used on
ensemble rank 0. Optional.

observation_iprod
observation_covariance
The inner product to calculate the observation error functional
from the observation error :math:`y_{0} - \\mathcal{H}_{0}(x)`.
Can include the error covariance matrix. Must be provided if
observation_err is provided. Only used on ensemble rank 0
observation_error is provided. Only used on ensemble rank 0

weak_constraint
Whether to use the weak or strong constraint 4DVar formulation.
Expand All @@ -123,18 +123,18 @@ class FourDVarReducedFunctional(ReducedFunctional):
"""

def __init__(self, control: Control,
background_iprod: Optional[Callable[[OverloadedType], AdjFloat]],
background_covariance: Union[Constant, tuple],
background: Optional[OverloadedType] = None,
observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None,
observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None,
observation_error: Optional[Callable[[OverloadedType], OverloadedType]] = None,
observation_covariance: Optional[Callable[[OverloadedType], AdjFloat]] = None,
weak_constraint: bool = True,
tape: Optional[Tape] = None,
_annotate_accumulation: bool = False):

self.tape = get_working_tape() if tape is None else tape

self.weak_constraint = weak_constraint
self.initial_observations = observation_err is not None
self.initial_observations = observation_error is not None

if self.weak_constraint:
self._annotate_accumulation = _annotate_accumulation
Expand Down Expand Up @@ -184,9 +184,9 @@ def __init__(self, control: Control,
control_name="Control_0_bkg_copy")

# RF to recalculate inner product |x_0 - x_b|_B
self.background_norm = isolated_rf(
operation=background_iprod,
control=self.background_error.functional,
self.background_norm = CovarianceNormReducedFunctional(
self.background_error.functional,
background_covariance,
control_name="bkg_err_vec_copy")

# compose background reduced functionals to evaluate both together
Expand All @@ -197,16 +197,16 @@ def __init__(self, control: Control,

# RF to recalculate error vector (H(x_0) - y_0)
self.initial_observation_error = isolated_rf(
operation=observation_err,
operation=observation_error,
control=_x[0],
functional_name="obs_err_vec_0",
control_name="Control_0_obs_copy")

# RF to recalculate inner product |H(x_0) - y_0|_R
self.initial_observation_norm = isolated_rf(
operation=observation_iprod,
control=self.initial_observation_error.functional,
functional_name="obs_err_vec_0_copy")
self.initial_observation_norm = CovarianceNormReducedFunctional(
self.initial_observation_error.functional,
observation_covariance,
control_name="obs_err_vec_0_copy")

# compose initial observation reduced functionals to evaluate both together
self.initial_observation_rf = CompositeReducedFunctional(
Expand Down Expand Up @@ -246,12 +246,14 @@ def __init__(self, control: Control,

# penalty for straying from prior
self._accumulate_functional(
background_iprod(control.control - self.background))
covariance_norm(control.control - self.background,
background_covariance))

# penalty for not hitting observations at initial time
if self.initial_observations:
self._accumulate_functional(
observation_iprod(observation_err(control.control)))
covariance_norm(observation_error(control.control),
observation_covariance))

@cached_property
def strong_reduced_functional(self):
Expand Down Expand Up @@ -745,8 +747,8 @@ def __init__(self, control: OverloadedType,
self.observation_index = observation_index

def set_observation(self, state: OverloadedType,
observation_err: Callable[[OverloadedType], OverloadedType],
observation_iprod: Callable[[OverloadedType], AdjFloat]):
observation_error: Callable[[OverloadedType], OverloadedType],
observation_covariance: Callable[[OverloadedType], AdjFloat]):
"""
Record an observation at the time of `state`.

Expand All @@ -756,14 +758,14 @@ def set_observation(self, state: OverloadedType,
state
The state at the current observation time.

observation_err
observation_error
Given a state :math:`x`, returns the observations error
:math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are
the observations at the current observation time and
:math:`\\mathcal{H}_{i}` is the observation operator for the
current observation time.

observation_iprod
observation_covariance
The inner product to calculate the observation error functional
from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`.
Can include the error covariance matrix.
Expand All @@ -772,7 +774,9 @@ def set_observation(self, state: OverloadedType,
raise ValueError("Cannot add observations once strong"
" constraint ReducedFunctional instantiated")
self.aaorf._accumulate_functional(
observation_iprod(observation_err(state)))
covariance_norm(observation_error(state),
observation_covariance))

# save the user's state to hand back for beginning of next stage
self.state = state

Expand Down Expand Up @@ -819,9 +823,9 @@ def __init__(self, control: Control,
self._stage_tape = get_working_tape()

def set_observation(self, state: OverloadedType,
observation_err: Callable[[OverloadedType], OverloadedType],
observation_iprod: Callable[[OverloadedType], AdjFloat],
forward_model_iprod: Callable[[OverloadedType], AdjFloat]):
observation_error: Callable[[OverloadedType], OverloadedType],
observation_covariance: Callable[[OverloadedType], AdjFloat],
forward_model_covariance: Callable[[OverloadedType], AdjFloat]):
"""
Record an observation at the time of `state`.

Expand All @@ -831,19 +835,19 @@ def set_observation(self, state: OverloadedType,
state
The state at the current observation time.

observation_err
observation_error
Given a state :math:`x`, returns the observations error
:math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are
the observations at the current observation time and
:math:`\\mathcal{H}_{i}` is the observation operator for the
current observation time.

observation_iprod
observation_covariance
The inner product to calculate the observation error functional
from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`.
Can include the error covariance matrix.

forward_model_iprod
forward_model_covariance
The inner product to calculate the model error functional from
the model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`. Can
include the error covariance matrix.
Expand Down Expand Up @@ -889,9 +893,9 @@ def set_observation(self, state: OverloadedType,
'control_name': f"model_err_vec_{self.global_index}_copy"
} if self.global_index else {}

self.model_norm = isolated_rf(
operation=forward_model_iprod,
control=self.model_error.functional,
self.model_norm = CovarianceNormReducedFunctional(
self.model_error.functional,
forward_model_covariance,
**names)

# compose model error reduced functionals to evaluate both together
Expand All @@ -907,17 +911,17 @@ def set_observation(self, state: OverloadedType,
} if self.global_index else {}

self.observation_error = isolated_rf(
operation=observation_err,
operation=observation_error,
control=self.controls[-1],
**names)

# RF to recalculate inner product |H(x_i) - y_i|_R
names = {
'functional_name': f"obs_err_vec_{self.global_index}_copy",
'control_name': f"obs_err_vec_{self.global_index}_copy"
} if self.global_index else {}
self.observation_norm = isolated_rf(
operation=observation_iprod,
control=self.observation_error.functional,
self.observation_norm = CovarianceNormReducedFunctional(
self.observation_error.functional,
observation_covariance,
**names)

# compose observation reduced functionals to evaluate both together
Expand Down Expand Up @@ -1106,3 +1110,31 @@ def _model_hessian(self, m_dot, options):
model_hessian._ad_convert_type(error_hessian[1],
options=options)
]


def covariance_norm(x, covariance):
    """Assemble the covariance-weighted error functional for ``x``.

    ``covariance`` is either a scalar-like weight ``C`` or a
    ``(C, power)`` pair; the assembled value ``inner(x, x/C)*dx`` is
    raised to ``power`` in the latter case (so the functional need not
    be quadratic, e.g. for Hessian tests).
    """
    power = None
    if isinstance(covariance, Collection):
        # unpack a (covariance, power) pair
        covariance, power = covariance
    # weight by the inverse covariance, as in |x|^2_{C^{-1}}
    inverse_weight = Constant(1/covariance)
    norm_value = assemble(inner(x, inverse_weight*x)*dx)
    if power is None:
        return norm_value
    return norm_value**power


class CovarianceNormReducedFunctional(ReducedFunctional):
    """A :class:`ReducedFunctional` for the covariance-weighted norm of ``x``.

    The norm is recorded on its own isolated tape so that it can be
    re-evaluated independently of the rest of the 4DVar tape.
    """

    def __init__(self, x, covariance,
                 functional_name=None,
                 control_name=None):
        # A (value, power) pair means the assembled norm is raised to `power`.
        if isinstance(covariance, Collection):
            self.covariance, self.power = covariance
        else:
            self.covariance, self.power = covariance, None

        # Record the weighted norm of x on an isolated tape, with the
        # covariance baked into the recorded operation.
        rf = isolated_rf(
            operation=partial(covariance_norm, covariance=covariance),
            control=x,
            functional_name=functional_name,
            control_name=control_name)
        super().__init__(rf.functional,
                         rf.controls.delist(rf.controls),
                         tape=rf.tape)
4 changes: 2 additions & 2 deletions firedrake/ensemblefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def riesz_representation(self, riesz_map="L2", **kwargs):
}[type(self)]
Vdual = [V.dual() for V in self.local_function_spaces]
riesz = DualType(self.ensemble, Vdual)
for u in riesz.subfunctions:
u.assign(u.riesz_representation(riesz_map=riesz_map, **kwargs))
for uself, uriesz in zip(self.subfunctions, riesz.subfunctions):
uriesz.assign(uself.riesz_representation(riesz_map=riesz_map, **kwargs))
return riesz

@PETSc.Log.EventDecorator()
Expand Down
48 changes: 26 additions & 22 deletions tests/firedrake/regression/test_4dvar_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,20 @@ def tendency(q, phi):
return qn, qn1, stepper


def prod2(w):
"""generate weighted inner products to pass to FourDVarReducedFunctional"""
def covariance_norm(covariance):
    """Generate weighted inner products to pass to FourDVarReducedFunctional.
    Use the quadratic norm so the Hessian is not linear.

    ``covariance`` is a ``(cov, power)`` pair; the returned callable
    assembles ``inner(x, x/cov)*dx`` raised to ``power``.
    """
    cov, power = covariance
    # weight by the inverse covariance, hoisted out of the returned closure
    weight = fd.Constant(1/cov)

    def n2(x):
        return fd.assemble(fd.inner(x, weight*x)*fd.dx)**power
    return n2


prodB = prod2(0.1) # background error
prodR = prod2(10.) # observation error
prodQ = prod2(1.0) # model error
B = (fd.Constant(10.), 2) # background error covariance
R = (fd.Constant(0.1), 2) # observation error covariance
Q = (fd.Constant(0.5), 2) # model error covariance


"""Advecting velocity"""
Expand Down Expand Up @@ -182,10 +186,10 @@ def strong_fdvar_pyadjoint(V):
set_working_tape()

# background functional
J = prodB(control - bkg)
J = covariance_norm(B)(control - bkg)

# initial observation functional
J += prodR(obs_errors(0)(control))
J += covariance_norm(R)(obs_errors(0)(control))

qn.assign(control)

Expand All @@ -198,7 +202,7 @@ def strong_fdvar_pyadjoint(V):
qn.assign(qn1)

# observation functional
J += prodR(obs_errors(i)(qn))
J += covariance_norm(R)(obs_errors(i)(qn))

pause_annotation()

Expand Down Expand Up @@ -226,9 +230,9 @@ def strong_fdvar_firedrake(V):

Jhat = FourDVarReducedFunctional(
Control(control),
background_iprod=prodB,
observation_iprod=prodR,
observation_err=obs_errors(0),
background_covariance=B,
observation_covariance=R,
observation_error=obs_errors(0),
weak_constraint=False)

# record observation stages
Expand All @@ -247,7 +251,7 @@ def strong_fdvar_firedrake(V):
# take observation
obs_index = stage.observation_index
stage.set_observation(qn, obs_errors(obs_index),
observation_iprod=prodR)
observation_covariance=R)

pause_annotation()
return Jhat
Expand All @@ -274,10 +278,10 @@ def weak_fdvar_pyadjoint(V):
set_working_tape()

# background error
J = prodB(controls[0] - bkg)
J = covariance_norm(B)(controls[0] - bkg)

# initial observation error
J += prodR(obs_errors(0)(controls[0]))
J += covariance_norm(R)(obs_errors(0)(controls[0]))

# record observation stages
for i in range(1, len(controls)):
Expand All @@ -298,10 +302,10 @@ def weak_fdvar_pyadjoint(V):
controls[i].assign(qn)

# model error for this stage
J += prodQ(qn - controls[i])
J += covariance_norm(Q)(qn - controls[i])

# observation error
J += prodR(obs_errors(i)(controls[i]))
J += covariance_norm(R)(obs_errors(i)(controls[i]))

pause_annotation()

Expand Down Expand Up @@ -340,9 +344,9 @@ def weak_fdvar_firedrake(V, ensemble):

Jhat = FourDVarReducedFunctional(
Control(control),
background_iprod=prodB,
observation_iprod=prodR,
observation_err=obs_errors(0),
background_covariance=B,
observation_covariance=R,
observation_error=obs_errors(0),
weak_constraint=True)

# record observation stages
Expand All @@ -362,8 +366,8 @@ def weak_fdvar_firedrake(V, ensemble):
# take observation
obs_err = obs_errors(stage.observation_index)
stage.set_observation(qn, obs_err,
observation_iprod=prodR,
forward_model_iprod=prodQ)
observation_covariance=R,
forward_model_covariance=Q)

pause_annotation()

Expand Down
Loading