Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing masks for nearest scheme #369

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 96 additions & 56 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
]


def _get_regrid_weights_dict(src_field, tgt_field, regrid_method):
def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, ignore_mask=False):
# The value, in array form, that ESMF should treat as an affirmative mask.
expected_mask = np.array([True])
mask_values = [] if ignore_mask else [1]
expected_mask = np.array(mask_values, dtype=np.int32)
regridder = esmpy.Regrid(
src_field,
tgt_field,
Expand Down Expand Up @@ -55,6 +56,31 @@ def _weights_dict_to_sparse_array(weights, shape, index_offsets):
return matrix


def _compute_weights_matrix(src, tgt, method, ignore_mask=False):
weights_dict = _get_regrid_weights_dict(
src.make_esmf_field(),
tgt.make_esmf_field(),
regrid_method=method.value,
ignore_mask=ignore_mask,
)
weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
(tgt._refined_size, src._refined_size),
(tgt.index_offset, src.index_offset),
)
if isinstance(tgt, RefinedGridInfo):
# At this point, the weight matrix represents more target points than
# tgt respresents. In order to collapse these points, we collapse the
# weights matrix by the appropriate matrix multiplication.
weight_matrix = tgt._collapse_weights(is_tgt=True) @ weight_matrix
if isinstance(src, RefinedGridInfo):
# At this point, the weight matrix represents more source points than
# src respresents. In order to collapse these points, we collapse the
# weights matrix by the appropriate matrix multiplication.
weight_matrix = weight_matrix @ src._collapse_weights(is_tgt=False)
return weight_matrix


class Regridder:
"""Regridder for directly interfacing with :mod:`esmpy`."""

Expand All @@ -81,10 +107,13 @@ def __init__(
shape is compatible with ``tgt``.
method : :class:`Constants.Method`
The method to be used to calculate weights.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
precomputed_weights : :class:`scipy.sparse.spmatrix` or tuple of :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`esmpy` will be used to
calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
and ``precomputed_weights`` will be used as the regridding weights.
If ``method`` is :obj:`Constants.Method.NEAREST`, a tuple with two
sparse matrices can be provided and these will be used as the
regridding weights for the data and mask respectively.
"""
self.src = src
self.tgt = tgt
Expand All @@ -94,45 +123,43 @@ def __init__(
self.esmf_regrid_version = esmf_regrid.__version__
if precomputed_weights is None:
self.esmf_version = esmpy.__version__
weights_dict = _get_regrid_weights_dict(
src.make_esmf_field(),
tgt.make_esmf_field(),
regrid_method=method.value,
)
self.weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
(self.tgt._refined_size, self.src._refined_size),
(self.tgt.index_offset, self.src.index_offset),
)
if isinstance(tgt, RefinedGridInfo):
# At this point, the weight matrix represents more target points than
# tgt respresents. In order to collapse these points, we collapse the
# weights matrix by the appropriate matrix multiplication.
self.weight_matrix = (
tgt._collapse_weights(is_tgt=True) @ self.weight_matrix
)
if isinstance(src, RefinedGridInfo):
# At this point, the weight matrix represents more source points than
# src respresents. In order to collapse these points, we collapse the
# weights matrix by the appropriate matrix multiplication.
self.weight_matrix = self.weight_matrix @ src._collapse_weights(
is_tgt=False
)
self.weight_matrix = _compute_weights_matrix(src, tgt, method)
if method is Constants.Method.NEAREST and (
src.mask is not None or tgt.mask is not None
):
# Nearest regridding will not create weights for masked points,
# therefore we need to ignore the masked points when computing
# the regridding weights for the mask.
self.mask_weight_matrix = _compute_weights_matrix(
src, tgt, method, ignore_mask=True
).astype(bool)
else:
self.mask_weight_matrix = None
else:
if not scipy.sparse.isspmatrix(precomputed_weights):
raise ValueError(
"Precomputed weights must be given as a sparse matrix."
)
if precomputed_weights.shape != (self.tgt.size, self.src.size):
msg = "Expected precomputed weights to have shape {}, got shape {} instead."
raise ValueError(
msg.format(
(self.tgt.size, self.src.size),
precomputed_weights.shape,
if not isinstance(precomputed_weights, (tuple, list)):
precomputed_weights = (precomputed_weights,)
for weight_matrix in precomputed_weights:
if not scipy.sparse.isspmatrix(weight_matrix):
raise ValueError(
"Precomputed weights must be given as a sparse matrix."
)
if weight_matrix.shape != (self.tgt.size, self.src.size):
msg = "Expected precomputed weights to have shape {}, got shape {} instead."
raise ValueError(
msg.format(
(self.tgt.size, self.src.size),
weight_matrix.shape,
)
)
)
self.esmf_version = None
self.weight_matrix = precomputed_weights
self.weight_matrix = precomputed_weights[0]
if (
self.method == Constants.Method.NEAREST
and len(precomputed_weights) == 2
):
self.mask_weight_matrix = precomputed_weights[1]
else:
self.mask_weight_matrix = None

def _out_dtype(self, in_dtype):
"""Return the expected output dtype for a given input dtype."""
Expand Down Expand Up @@ -178,24 +205,37 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
f"got an array with shape ending in {main_shape}."
)
extra_shape = array_shape[: -self.src.dims]
extra_size = max(1, np.prod(extra_shape))
src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array))
weight_sums = self.weight_matrix @ src_inverted_mask
out_dtype = self._out_dtype(src_array.dtype)
# Set the minimum mdtol to be slightly higher than 0 to account for rounding
# errors.
mdtol = max(mdtol, 1e-8)
tgt_mask = weight_sums > 1 - mdtol
masked_weight_sums = weight_sums * tgt_mask
normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype)
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == Constants.NormType.DSTAREA:
pass
normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))

flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0))
flat_tgt = self.weight_matrix @ flat_src
flat_tgt = flat_tgt * normalisations

# Handle normalization and masking.
flat_src_mask = self.src._array_to_matrix(ma.getmaskarray(src_array))
if self.method == Constants.Method.NEAREST:
# Normalization is not required in this case because no destination
# point will receive input from more than one source point.
if self.mask_weight_matrix is None:
flat_tgt_mask = self.weight_matrix.astype(bool) @ flat_src_mask
else:
flat_tgt_mask = self.mask_weight_matrix @ flat_src_mask
if self.tgt.mask is not None:
flat_tgt_mask |= self.tgt._array_to_matrix(self.tgt.mask)
flat_tgt = ma.masked_array(flat_tgt, flat_tgt_mask)
else:
extra_size = max(1, np.prod(extra_shape))
weight_sums = self.weight_matrix @ ~flat_src_mask
out_dtype = self._out_dtype(src_array.dtype)
# Set the minimum mdtol to be slightly higher than 0 to account for rounding
# errors.
mdtol = max(mdtol, 1e-8)
tgt_mask = weight_sums > 1 - mdtol
masked_weight_sums = weight_sums * tgt_mask
normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype)
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == Constants.NormType.DSTAREA:
pass
normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))
flat_tgt = flat_tgt * normalisations

tgt_array = self.tgt._matrix_to_array(flat_tgt, extra_shape)
return tgt_array
143 changes: 143 additions & 0 deletions esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,90 @@ def test_Regridder_init_fail():
_ = Regridder(src_grid, tgt_grid, method="other")


@pytest.fixture
def nearest_weights():
"""Weights matrix for testing the nearest neighbour regridder."""
weight_list = np.ones(3)
src_idx = np.ones(3)
tgt_idx = np.arange(3)

shape = (4, 2)

weights = scipy.sparse.csr_matrix((weight_list, (tgt_idx, src_idx)), shape=shape)
return weights


@pytest.fixture
def nearest_mask_weights():
"""Mask weights matrix for testing the nearest neighbour regridder."""
weight_list = np.ones(4, dtype=bool)
src_idx = np.array([0, 0, 1, 1])
tgt_idx = np.arange(4)

shape = (4, 2)

weights = scipy.sparse.csr_matrix((weight_list, (tgt_idx, src_idx)), shape=shape)
return weights


@pytest.fixture
def nearest_grids():
"""Source and target grids for testing the nearest neighbour regridder."""

# The following ASCII visualisation describes the source and target grid
# indices and the mask (m) which ESMF assigns to their cells when
# computing nearest neighbour weights. Masked cells are not used in the
# resulting weights matrix, but should be used in the resulting mask
# weights matrix.
#
# 20 +-------+ +---+---+
# | 1 | | 1 | m |
# 10 +-------+ +---+---+
# | m | | 0 | 2 |
# 0 +-------+ +---+---+
# 0 20 0 10 20
def _get_points(bounds):
points = (bounds[:-1] + bounds[1:]) / 2
return points

lon_bounds = np.array([0, 10, 20])
lat_bounds = np.array([0, 20])
lon, lat = _get_points(lon_bounds), _get_points(lat_bounds)
src_mask = np.zeros((len(lat), len(lon)), dtype=bool)
src_mask[0, 0] = True
src_grid = GridInfo(lon, lat, lon_bounds, lat_bounds, center=True, mask=src_mask)

lon_bounds = np.array([0, 10, 20])
lat_bounds = np.array([0, 10, 20])
lon, lat = _get_points(lon_bounds), _get_points(lat_bounds)
tgt_mask = np.zeros((len(lat), len(lon)), dtype=bool)
tgt_mask[1, 1] = True
tgt_grid = GridInfo(lon, lat, lon_bounds, lat_bounds, center=True, mask=tgt_mask)
return src_grid, tgt_grid


def test_Regridder_init_nearest_masked(
nearest_grids,
nearest_weights,
nearest_mask_weights,
):
"""Test :meth:`~esmf_regrid.esmf_regridder.Regridder.__init__`.

Check that the mask on the source and target array is respected for the
weights calculation and ignored for the mask weights calculation when
using nearest neighbour regridding.
"""
src_grid, tgt_grid = nearest_grids

rg = Regridder(src_grid, tgt_grid, method=Constants.Method.NEAREST)

result = rg.weight_matrix
assert np.allclose(result.toarray(), nearest_weights.toarray())

result = rg.mask_weight_matrix
assert np.allclose(result.toarray(), nearest_mask_weights.toarray())


def test_Regridder_regrid():
"""Basic test for :meth:`~esmf_regrid.esmf_regridder.Regridder.regrid`."""
lon, lat, lon_bounds, lat_bounds = make_grid_args(2, 3)
Expand Down Expand Up @@ -154,6 +238,65 @@ def _give_extra_dims(array):
_ = rg.regrid(src_masked, norm_type="INVALID")


def test_Regridder_regrid_nearest_masked(
nearest_grids,
nearest_weights,
nearest_mask_weights,
):
"""Test for :meth:`~esmf_regrid.esmf_regridder.Regridder.regrid`.

Check that the mask is computed correctly from the source array and that
the target mask is applied.
"""
src_grid, tgt_grid = nearest_grids

# Set up the regridder with precomputed weights.
rg = Regridder(
src_grid,
tgt_grid,
method=Constants.Method.NEAREST,
precomputed_weights=(nearest_weights, nearest_mask_weights),
)

src_array = np.arange(4).reshape((2, 1, 2))
src_masked = ma.array(
src_array,
mask=np.array(
[
[[True, False]],
[[True, False]],
]
),
)

result = rg.regrid(src_masked)
expected = ma.masked_array(
[
[
[0, 1.0],
[0, 0],
],
[
[0, 3.0],
[0, 0],
],
],
mask=np.array(
[
[
[True, False],
[True, True],
],
[
[True, False],
[True, True],
],
]
),
)
assert ma.allclose(result, expected)


def test_Regridder_init_small():
"""
Simplified test for :meth:`~esmf_regrid.esmf_regridder.Regridder.regrid`.
Expand Down