Skip to content

Commit

Permalink
Fix error when plotting empty transects
Browse files Browse the repository at this point in the history
Fixes #119
  • Loading branch information
mx-moth committed Jan 11, 2024
1 parent 395514d commit 40796d7
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 20 deletions.
30 changes: 24 additions & 6 deletions src/emsarray/conventions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,21 +268,19 @@ def bind(self) -> None:
"cannot assign a new convention.")
state.bind_convention(self)

@abc.abstractmethod
def _get_data_array(self, data_array: DataArrayOrName) -> xarray.DataArray:
    """
    Utility to help get a data array for this dataset.

    If a string is passed in, the matching data array is fetched from the dataset.
    If a data array is passed in,
    it is inspected to ensure the surface dimensions align
    before being returned as-is.

    This is useful for methods that support being passed either
    the name of a data array or a data array instance.

    This method is abstract: it has no implementation here,
    and concrete conventions must provide one.

    Parameters
    ----------
    data_array : DataArrayOrName
        Either the name of a data array in the dataset,
        or a data array instance.

    Returns
    -------
    xarray.DataArray
        The data array that was looked up by name,
        or the data array instance that was passed in.
    """

@cached_property
def time_coordinate(self) -> xarray.DataArray:
Expand Down Expand Up @@ -1011,6 +1009,9 @@ def animate_on_figure(
if coordinate is None:
# Assume the user wants to plot along the time axis by default.
coordinate = self.get_time_name()
if isinstance(coordinate, xarray.DataArray):
utils.check_data_array_dimensions_match(
self.dataset, coordinate, coordinate.dims)

coordinate = self._get_data_array(coordinate)

Expand Down Expand Up @@ -1787,6 +1788,23 @@ def get_grid_kind(self, data_array: xarray.DataArray) -> GridKind:
return kind
raise ValueError("Unknown grid kind")

def _get_data_array(self, data_array: DataArrayOrName) -> xarray.DataArray:
    """
    Resolve a data array, or the name of a data array, to a data array.

    If a name is passed in, the matching data array is fetched from the dataset.
    If a data array is passed in, the sizes of its grid dimensions
    are compared to the sizes of the same dimensions on the dataset
    before it is returned as-is.
    Only the dimensions of the grid kind the data array is defined on are checked.

    Parameters
    ----------
    data_array : DataArrayOrName
        Either the name of a data array in the dataset,
        or a data array instance.

    Returns
    -------
    xarray.DataArray
        The data array that was looked up by name,
        or the data array instance that was passed in.

    Raises
    ------
    ValueError
        If a data array is passed in and one of its grid dimensions
        has a different size to the same dimension on the dataset.
        ``get_grid_kind()`` may also raise this
        if the grid kind of the data array can not be determined.
    """
    if not isinstance(data_array, xarray.DataArray):
        return self.dataset[data_array]

    grid_kind = self.get_grid_kind(data_array)
    for dimension in self.grid_dimensions[grid_kind]:
        # The data array already has matching dimension names
        # as we found the grid kind using `Convention.get_grid_kind()`.
        dataset_size = self.dataset.sizes[dimension]
        data_array_size = data_array.sizes[dimension]
        if dataset_size != data_array_size:
            # Every fragment that interpolates a value needs its own `f` prefix.
            # Adjacent string literals are concatenated,
            # but the prefix does not carry over from one literal to the next.
            raise ValueError(
                f"Mismatched dimension {dimension!r}, "
                f"dataset has size {dataset_size} but "
                f"data array has size {data_array_size}!"
            )
    return data_array

@abc.abstractmethod
def unpack_index(self, index: Index) -> Tuple[GridKind, Sequence[int]]:
"""Convert a native index in to a grid kind and dimension indices.
Expand Down
26 changes: 20 additions & 6 deletions src/emsarray/transect.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,17 @@ def transect_dataset(self) -> xarray.Dataset:
dims=(depth_dimension, 'bounds'),
)
distance_bounds = xarray.DataArray(
data=[
[segment.start_distance, segment.end_distance]
for segment in self.segments
],
data=numpy.fromiter(
(
[segment.start_distance, segment.end_distance]
for segment in self.segments
),
# Be explicit here, to handle the case when len(self.segments) == 0.
# This happens when the transect line does not intersect the dataset.
# This will result in an empty transect plot.
count=len(self.segments),
dtype=numpy.dtype((float, 2)),
),
dims=('index', 'bounds'),
attrs={
'long_name': 'Distance along transect',
Expand Down Expand Up @@ -762,8 +769,15 @@ def _plot_on_figure(

cmap = colormaps[cmap].copy()
cmap.set_bad(ocean_floor_colour)
collection = self.make_poly_collection(
cmap=cmap, clim=(numpy.nanmin(data_array), numpy.nanmax(data_array)))

if data_array.size != 0:
clim = (numpy.nanmin(data_array), numpy.nanmax(data_array))
else:
# An empty data array happens when the transect line does not
# intersect the dataset geometry.
clim = None

collection = self.make_poly_collection(cmap=cmap, clim=clim, edgecolor='face')
axes.add_collection(collection)

if bathymetry is not None:
Expand Down
32 changes: 24 additions & 8 deletions src/emsarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,12 @@ def dimensions_from_coords(
return dimensions


def check_data_array_dimensions_match(dataset: xarray.Dataset, data_array: xarray.DataArray) -> None:
def check_data_array_dimensions_match(
dataset: xarray.Dataset,
data_array: xarray.DataArray,
*,
dimensions: Optional[Sequence[Hashable]] = None,
) -> None:
"""
Check that the dimensions of a :class:`xarray.DataArray`
match the dimensions of a :class:`xarray.Dataset`.
Expand All @@ -438,27 +443,38 @@ def check_data_array_dimensions_match(dataset: xarray.Dataset, data_array: xarra
Parameters
----------
dataset
dataset : xarray.Dataset
The dataset used as a reference
data_array
data_array : xarray.DataArray
The data array to check the dimensions of
dimensions: list of Hashable, optional
The dimension names to check for equal sizes.
Optional, defaults to checking all dimensions on the data array.
Raises
------
ValueError
Raised if the dimensions do not match
"""
for dimension, data_array_size in zip(data_array.dims, data_array.shape):
if dimension not in dataset.dims:
if dimensions is None:
dimensions = data_array.dims

for dimension in dimensions:
if dimension not in dataset.dims and dimension not in data_array.dims:
raise ValueError(
f"Data array has unknown dimension {dimension} of size {data_array_size}"
f"Dimension {dimension!r} not present on either dataset or data array"
)

elif dimension not in dataset.dims:
raise ValueError(f"Dataset does not have dimension {dimension!r}")
elif dimension not in data_array.dims:
raise ValueError(f"Data array does not have dimension {dimension!r}")
dataset_size = dataset.sizes[dimension]
data_array_size = data_array.sizes[dimension]

if data_array_size != dataset_size:
raise ValueError(
"Dimension mismatch between dataset and data array: "
f"Dataset dimension {dimension} has size {dataset_size}, "
f"Dataset dimension {dimension!r} has size {dataset_size}, "
f"data array has size {data_array_size}"
)

Expand Down
6 changes: 6 additions & 0 deletions tests/conventions/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class SimpleConvention(Convention[SimpleGridKind, SimpleGridIndex]):
def check_dataset(cls, dataset: xarray.Dataset) -> Optional[int]:
return None

def _get_data_array(self, data_array_or_name) -> xarray.DataArray:
    """
    Minimal lookup used by the test convention.

    Strings are treated as names and looked up in the dataset.
    Anything else is handed back untouched, without any dimension checks.
    """
    if not isinstance(data_array_or_name, str):
        return data_array_or_name
    return self.dataset[data_array_or_name]

@cached_property
def shape(self) -> Tuple[int, int]:
y, x = map(int, self.dataset['botz'].shape)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_transect.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,45 @@ def test_plot(

matplotlib.pyplot.savefig(tmp_path / 'plot.png')
logger.info("Saved plot to %r", tmp_path / 'plot.png')


@pytest.mark.matplotlib(mock_coast=True)
@pytest.mark.tutorial
def test_plot_no_intersection(
    datasets: pathlib.Path,
    tmp_path: pathlib.Path,
):
    """
    Plot a transect along a line that never touches the dataset geometry.

    The expected outcome is an empty but fully labelled transect plot,
    which is friendlier than raising an error.
    """
    gbr4 = emsarray.tutorial.open_dataset('gbr4')
    # Work on a copy of the temperature variable at the final time step.
    temperature = gbr4['temp'].copy().isel(time=-1)

    # This path runs through the Bass Strait, nowhere near the GBR,
    # as if someone had picked the wrong dataset.
    bass_strait = shapely.LineString([
        [142.097168, -39.206719],
        [145.393066, -39.3088],
        [149.798584, -39.172659],
    ])
    emsarray.transect.plot(
        gbr4, bass_strait, temperature,
        bathymetry=gbr4['botz'])

    figure = matplotlib.pyplot.gcf()

    plot_axes = figure.axes[0]
    # The title is assembled from the variable long_name and the time coordinate
    assert plot_axes.get_title() == 'Temperature\n2022-05-11T14:00'
    # The y label is the long_name of the depth coordinate
    assert plot_axes.get_ylabel() == 'Z coordinate'
    # The x label is made up
    assert plot_axes.get_xlabel() == 'Distance along transect'

    colorbar_axes = figure.axes[-1]
    # The colorbar label is the variable units
    assert colorbar_axes.get_ylabel() == 'degrees C'

    plot_path = tmp_path / 'plot.png'
    matplotlib.pyplot.savefig(plot_path)
    logger.info("Saved plot to %r", plot_path)

0 comments on commit 40796d7

Please sign in to comment.