From af2171cb838eaa1e01a9a098bc6418e418f1ac0b Mon Sep 17 00:00:00 2001 From: Romain Hugonnet Date: Tue, 19 Mar 2024 13:36:46 -0800 Subject: [PATCH] Add `Raster.from_xarray()` to create raster from a `xr.DataArray` (#521) --- geoutils/raster/raster.py | 97 ++++++++++++++++++++++++++++++++------- tests/test_raster.py | 52 +++++++++++++++++++-- 2 files changed, 129 insertions(+), 20 deletions(-) diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index a886b0a9..1fbfb83b 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -969,13 +969,17 @@ def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Numb self._data[:, ind] = assign # type: ignore return None - def raster_equal(self, other: RasterType) -> bool: + def raster_equal(self, other: RasterType, strict_masked: bool = True, warn_failure_reason: bool = False) -> bool: """ Check if two rasters are equal. This means that are equal: - The raster's masked array's data (including masked values), mask, fill_value and dtype, - The raster's transform, crs and nodata values. + + :param other: Other raster. + :param strict_masked: Whether to check if masked cells (in .data.mask) have the same value (in .data.data). + :param warn_failure_reason: Whether to warn for the reason of failure if the check does not pass. """ # If the mask is just "False", it is equivalent to being equal to an array of False @@ -991,8 +995,10 @@ def raster_equal(self, other: RasterType) -> bool: if not isinstance(other, Raster): # TODO: Possibly add equals to SatelliteImage? raise NotImplementedError("Equality with other object than Raster not supported by raster_equal.") - return all( - [ + + if strict_masked: + names = ["data.data", "data.mask", "data.fill_value", "dtype", "transform", "crs", "nodata"] + equalities = [ np.array_equal(self.data.data, other.data.data, equal_nan=True), np.array_equal(self_mask, other_mask), self.data.fill_value == other.data.fill_value, @@ -1001,7 +1007,26 @@ def raster_equal(self, other: RasterType) -> bool: self.crs == other.crs, self.nodata == other.nodata, ] - ) + else: + names = ["data", "data.fill_value", "dtype", "transform", "crs", "nodata"] + equalities = [ + np.ma.allequal(self.data, other.data), + self.data.fill_value == other.data.fill_value, + self.data.dtype == other.data.dtype, + self.transform == other.transform, + self.crs == other.crs, + self.nodata == other.nodata, + ] + + complete_equality = all(equalities) + + if not complete_equality and warn_failure_reason: + where_fail = np.nonzero(~np.array(equalities))[0] + warnings.warn( + category=UserWarning, message=f"Equality failed for: {', '.join([names[w] for w in where_fail])}." + ) + + return complete_equality def _overloading_check( self: RasterType, other: RasterType | NDArrayNum | Number @@ -1336,18 +1361,24 @@ def __ge__(self: RasterType, other: RasterType | NDArrayNum | Number) -> RasterT return out_mask @overload - def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[False] = False) -> Raster: + def astype( + self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[False] = False + ) -> RasterType: ... @overload - def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[True]) -> None: + def astype(self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[True]) -> None: ... @overload - def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: bool = False) -> Raster | None: + def astype( + self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: bool = False + ) -> RasterType | None: ... - def astype(self, dtype: DTypeLike, convert_nodata: bool = True, inplace: bool = False) -> Raster | None: + def astype( + self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, inplace: bool = False + ) -> RasterType | None: """ Convert data type of the raster. @@ -1523,6 +1554,7 @@ def set_nodata( # Update the nodata value self._nodata = new_nodata + self.data.fill_value = new_nodata @property def data(self) -> MArrayNum: @@ -2629,22 +2661,55 @@ def save( dst.gcps = (rio_gcps, gcps_crs) + @classmethod + def from_xarray(cls: type[RasterType], ds: xr.DataArray, dtype: DTypeLike | None = None) -> RasterType: + """ + Create raster from a xarray.DataArray. + + This conversion loads the xarray.DataArray in memory. Use functions of the Xarray accessor directly + to avoid this behaviour. + + :param ds: Data array. + :param dtype: Cast the array to a certain dtype. + + :return: Raster. + """ + + # Define main attributes + crs = ds.rio.crs + transform = ds.rio.transform(recalc=True) + nodata = ds.rio.nodata + + # TODO: Add tags and area_or_point with PR #509 + raster = cls.from_array(data=ds.data, transform=transform, crs=crs, nodata=nodata) + + if dtype is not None: + raster = raster.astype(dtype) + + return raster + def to_xarray(self, name: str | None = None) -> xr.DataArray: """ Convert raster to a xarray.DataArray. - This method uses rioxarray to generate a DataArray with associated - geo-referencing information. + This converts integer-type rasters into float32. - See the documentation of rioxarray and xarray for more information on - the methods and attributes of the resulting DataArray. + :param name: Name attribute for the data array. - :param name: Name attribute for the DataArray. - - :returns: xarray DataArray + :returns: Data array. """ - ds = rioxarray.open_rasterio(self.to_rio_dataset()) + # If type was integer, cast to float to be able to save nodata values in the xarray data array + if np.issubdtype(self.dtypes[0], np.integer): + # Nodata conversion is not needed in this direction (integer towards float), we can maintain the original + updated_raster = self.astype(np.float32, convert_nodata=False) + else: + updated_raster = self + + ds = rioxarray.open_rasterio(updated_raster.to_rio_dataset(), masked=True) + # When reading as masked, the nodata is not written to the dataset so we do it manually + ds.rio.set_nodata(self.nodata) + if name is not None: ds.name = name diff --git a/tests/test_raster.py b/tests/test_raster.py index 97a68d86..cb88c2d6 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -363,7 +363,7 @@ def test_to_rio_dataset(self, example: str): def test_to_xarray(self, example: str): """Test the export to a xarray dataset""" - # Open raster and export to rio dataset + # Open raster and export to xarray dataset rst = gu.Raster(example) ds = rst.to_xarray() @@ -391,9 +391,32 @@ def test_to_xarray(self, example: str): # Check that the arrays are equal in NaN type if rst.count > 1: - assert np.array_equal(rst.data.data, ds.data) + assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True) else: - assert np.array_equal(rst.data.data, ds.data.squeeze()) + assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True) + + @pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore + def test_from_xarray(self, example: str): + """Test raster creation from a xarray dataset, not fully reversible with to_xarray due to float conversion""" + + # Open raster and export to xarray, then import to xarray dataset + rst = gu.Raster(example) + ds = rst.to_xarray() + rst2 = gu.Raster.from_xarray(ds=ds) + + # Exporting to a Xarray dataset results in loss of information to float32 + # Check that the output equals the input converted to float32 (not fully reversible) + assert rst.astype("float32", convert_nodata=False).raster_equal(rst2, strict_masked=False) + + # Test with the dtype argument to convert back to original raster even if integer-type + if np.issubdtype(rst.dtypes[0], np.integer): + # Set an existing nodata value, because all of our integer-type example datasets currently have "None" + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="New nodata value cells already exist.*") + rst.set_nodata(new_nodata=255) + ds = rst.to_xarray() + rst3 = gu.Raster.from_xarray(ds=ds, dtype=rst.dtypes[0]) + assert rst3.raster_equal(rst, strict_masked=False) @pytest.mark.parametrize("nodata_init", [None, "type_default"]) # type: ignore @pytest.mark.parametrize( @@ -2189,6 +2212,7 @@ def test_set_nodata(self, example: str) -> None: # The nodata value should have been set in the metadata assert r.nodata == new_nodata + assert r.data.fill_value == new_nodata # By default, the array should have been updated if old_nodata is not None: @@ -2227,6 +2251,7 @@ def test_set_nodata(self, example: str) -> None: # The nodata value should have been set in the metadata assert r.nodata == new_nodata + assert r.data.fill_value == new_nodata # By default, the array should have been updated similarly for the old nodata if old_nodata is not None: @@ -2269,6 +2294,7 @@ def test_set_nodata(self, example: str) -> None: # The nodata value should have been set in the metadata assert r.nodata == new_nodata + assert r.data.fill_value == new_nodata # Now, the array should not have been updated, so the entire array should be unchanged except for the pixel assert np.array_equal(r.data.data[~mask_pixel_artificially_set], r_copy.data.data[~mask_pixel_artificially_set]) @@ -2297,6 +2323,7 @@ def test_set_nodata(self, example: str) -> None: # The nodata value should have been set in the metadata assert r.nodata == new_nodata + assert r.data.fill_value == new_nodata # The array should have been updated if old_nodata is not None: @@ -2323,6 +2350,7 @@ def test_set_nodata(self, example: str) -> None: # The nodata value should have been set in the metadata assert r.nodata == new_nodata + assert r.data.fill_value == new_nodata # The array should not have been updated except for the pixel assert np.array_equal(r.data.data[~mask_pixel_artificially_set], r_copy.data.data[~mask_pixel_artificially_set]) @@ -3204,7 +3232,7 @@ def test_reproject(self, mask: gu.Mask) -> None: match="Reprojecting a mask with a resampling method other than 'nearest', " "the boolean array will be converted to float during interpolation.", ): - mask.reproject(resampling="bilinear") + mask.reproject(res=50, resampling="bilinear", force_source_nodata=2) @pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore def test_crop(self, mask: gu.Mask) -> None: @@ -3437,6 +3465,22 @@ def test_raster_equal(self) -> None: r2.set_nodata(34) assert not r1.raster_equal(r2) + # Change value of a masked cell + r2 = r1.copy() + r2.data[0, 0] = np.ma.masked + r2.data.data[0, 0] = 0 + r3 = r2.copy() + r3.data.data[0, 0] = 10 + assert not r2.raster_equal(r3) + assert r2.raster_equal(r3, strict_masked=False) + + # Check that a warning is raised with useful information without equality + with pytest.warns(UserWarning, match="Equality failed for: data.data."): + assert not r2.raster_equal(r3, warn_failure_reason=True) + + # But no warning is raised for an equality + assert r2.raster_equal(r3, strict_masked=False, warn_failure_reason=True) + def test_equal_georeferenced_grid(self) -> None: """ Test that equal for shape, crs and transform work as expected