Skip to content

Commit

Permalink
Add Raster.from_xarray() to create raster from a xr.DataArray (#521)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet authored Mar 19, 2024
1 parent 83d6d10 commit af2171c
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 20 deletions.
97 changes: 81 additions & 16 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,13 +969,17 @@ def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Numb
self._data[:, ind] = assign # type: ignore
return None

def raster_equal(self, other: RasterType) -> bool:
def raster_equal(self, other: RasterType, strict_masked: bool = True, warn_failure_reason: bool = False) -> bool:
"""
Check if two rasters are equal.
This means that are equal:
- The raster's masked array's data (including masked values), mask, fill_value and dtype,
- The raster's transform, crs and nodata values.
:param other: Other raster.
:param strict_masked: Whether to check if masked cells (in .data.mask) have the same value (in .data.data).
:param warn_failure_reason: Whether to warn for the reason of failure if the check does not pass.
"""

# If the mask is just "False", it is equivalent to being equal to an array of False
Expand All @@ -991,8 +995,10 @@ def raster_equal(self, other: RasterType) -> bool:

if not isinstance(other, Raster): # TODO: Possibly add equals to SatelliteImage?
raise NotImplementedError("Equality with other object than Raster not supported by raster_equal.")
return all(
[

if strict_masked:
names = ["data.data", "data.mask", "data.fill_value", "dtype", "transform", "crs", "nodata"]
equalities = [
np.array_equal(self.data.data, other.data.data, equal_nan=True),
np.array_equal(self_mask, other_mask),
self.data.fill_value == other.data.fill_value,
Expand All @@ -1001,7 +1007,26 @@ def raster_equal(self, other: RasterType) -> bool:
self.crs == other.crs,
self.nodata == other.nodata,
]
)
else:
names = ["data", "data.fill_value", "dtype", "transform", "crs", "nodata"]
equalities = [
np.ma.allequal(self.data, other.data),
self.data.fill_value == other.data.fill_value,
self.data.dtype == other.data.dtype,
self.transform == other.transform,
self.crs == other.crs,
self.nodata == other.nodata,
]

complete_equality = all(equalities)

if not complete_equality and warn_failure_reason:
where_fail = np.nonzero(~np.array(equalities))[0]
warnings.warn(
category=UserWarning, message=f"Equality failed for: {', '.join([names[w] for w in where_fail])}."
)

return complete_equality

def _overloading_check(
self: RasterType, other: RasterType | NDArrayNum | Number
Expand Down Expand Up @@ -1336,18 +1361,24 @@ def __ge__(self: RasterType, other: RasterType | NDArrayNum | Number) -> RasterT
return out_mask

@overload
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[False] = False) -> Raster:
def astype(
self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[False] = False
) -> RasterType:
...

@overload
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[True]) -> None:
def astype(self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[True]) -> None:
...

@overload
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: bool = False) -> Raster | None:
def astype(
self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: bool = False
) -> RasterType | None:
...

def astype(self, dtype: DTypeLike, convert_nodata: bool = True, inplace: bool = False) -> Raster | None:
def astype(
self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, inplace: bool = False
) -> RasterType | None:
"""
Convert data type of the raster.
Expand Down Expand Up @@ -1523,6 +1554,7 @@ def set_nodata(

# Update the nodata value
self._nodata = new_nodata
self.data.fill_value = new_nodata

@property
def data(self) -> MArrayNum:
Expand Down Expand Up @@ -2629,22 +2661,55 @@ def save(

dst.gcps = (rio_gcps, gcps_crs)

@classmethod
def from_xarray(cls: type[RasterType], ds: xr.DataArray, dtype: DTypeLike | None = None) -> RasterType:
"""
Create raster from a xarray.DataArray.
This conversion loads the xarray.DataArray in memory. Use functions of the Xarray accessor directly
to avoid this behaviour.
:param ds: Data array.
:param dtype: Cast the array to a certain dtype.
:return: Raster.
"""

# Define main attributes
crs = ds.rio.crs
transform = ds.rio.transform(recalc=True)
nodata = ds.rio.nodata

# TODO: Add tags and area_or_point with PR #509
raster = cls.from_array(data=ds.data, transform=transform, crs=crs, nodata=nodata)

if dtype is not None:
raster = raster.astype(dtype)

return raster

def to_xarray(self, name: str | None = None) -> xr.DataArray:
"""
Convert raster to a xarray.DataArray.
This method uses rioxarray to generate a DataArray with associated
geo-referencing information.
This converts integer-type rasters into float32.
See the documentation of rioxarray and xarray for more information on
the methods and attributes of the resulting DataArray.
:param name: Name attribute for the data array.
:param name: Name attribute for the DataArray.
:returns: xarray DataArray
:returns: Data array.
"""

ds = rioxarray.open_rasterio(self.to_rio_dataset())
# If type was integer, cast to float to be able to save nodata values in the xarray data array
if np.issubdtype(self.dtypes[0], np.integer):
# Nodata conversion is not needed in this direction (integer towards float), we can maintain the original
updated_raster = self.astype(np.float32, convert_nodata=False)
else:
updated_raster = self

ds = rioxarray.open_rasterio(updated_raster.to_rio_dataset(), masked=True)
# When reading as masked, the nodata is not written to the dataset so we do it manually
ds.rio.set_nodata(self.nodata)

if name is not None:
ds.name = name

Expand Down
52 changes: 48 additions & 4 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def test_to_rio_dataset(self, example: str):
def test_to_xarray(self, example: str):
"""Test the export to a xarray dataset"""

# Open raster and export to rio dataset
# Open raster and export to xarray dataset
rst = gu.Raster(example)
ds = rst.to_xarray()

Expand Down Expand Up @@ -391,9 +391,32 @@ def test_to_xarray(self, example: str):

# Check that the arrays are equal in NaN type
if rst.count > 1:
assert np.array_equal(rst.data.data, ds.data)
assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True)
else:
assert np.array_equal(rst.data.data, ds.data.squeeze())
assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True)

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore
def test_from_xarray(self, example: str):
"""Test raster creation from a xarray dataset, not fully reversible with to_xarray due to float conversion"""

# Open raster and export to xarray, then import to xarray dataset
rst = gu.Raster(example)
ds = rst.to_xarray()
rst2 = gu.Raster.from_xarray(ds=ds)

# Exporting to a Xarray dataset results in loss of information to float32
# Check that the output equals the input converted to float32 (not fully reversible)
assert rst.astype("float32", convert_nodata=False).raster_equal(rst2, strict_masked=False)

# Test with the dtype argument to convert back to original raster even if integer-type
if np.issubdtype(rst.dtypes[0], np.integer):
# Set an existing nodata value, because all of our integer-type example datasets currently have "None"
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="New nodata value cells already exist.*")
rst.set_nodata(new_nodata=255)
ds = rst.to_xarray()
rst3 = gu.Raster.from_xarray(ds=ds, dtype=rst.dtypes[0])
assert rst3.raster_equal(rst, strict_masked=False)

@pytest.mark.parametrize("nodata_init", [None, "type_default"]) # type: ignore
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2189,6 +2212,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# By default, the array should have been updated
if old_nodata is not None:
Expand Down Expand Up @@ -2227,6 +2251,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# By default, the array should have been updated similarly for the old nodata
if old_nodata is not None:
Expand Down Expand Up @@ -2269,6 +2294,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# Now, the array should not have been updated, so the entire array should be unchanged except for the pixel
assert np.array_equal(r.data.data[~mask_pixel_artificially_set], r_copy.data.data[~mask_pixel_artificially_set])
Expand Down Expand Up @@ -2297,6 +2323,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# The array should have been updated
if old_nodata is not None:
Expand All @@ -2323,6 +2350,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# The array should not have been updated except for the pixel
assert np.array_equal(r.data.data[~mask_pixel_artificially_set], r_copy.data.data[~mask_pixel_artificially_set])
Expand Down Expand Up @@ -3204,7 +3232,7 @@ def test_reproject(self, mask: gu.Mask) -> None:
match="Reprojecting a mask with a resampling method other than 'nearest', "
"the boolean array will be converted to float during interpolation.",
):
mask.reproject(resampling="bilinear")
mask.reproject(res=50, resampling="bilinear", force_source_nodata=2)

@pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore
def test_crop(self, mask: gu.Mask) -> None:
Expand Down Expand Up @@ -3437,6 +3465,22 @@ def test_raster_equal(self) -> None:
r2.set_nodata(34)
assert not r1.raster_equal(r2)

# Change value of a masked cell
r2 = r1.copy()
r2.data[0, 0] = np.ma.masked
r2.data.data[0, 0] = 0
r3 = r2.copy()
r3.data.data[0, 0] = 10
assert not r2.raster_equal(r3)
assert r2.raster_equal(r3, strict_masked=False)

# Check that a warning is raised with useful information without equality
with pytest.warns(UserWarning, match="Equality failed for: data.data."):
assert not r2.raster_equal(r3, warn_failure_reason=True)

# But no warning is raised for an equality
assert r2.raster_equal(r3, strict_masked=False, warn_failure_reason=True)

def test_equal_georeferenced_grid(self) -> None:
"""
Test that equal for shape, crs and transform work as expected
Expand Down

0 comments on commit af2171c

Please sign in to comment.