diff --git a/docs/api.rst b/docs/api.rst index 45b51bea5..ff1ee247b 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -595,6 +595,7 @@ Clip :template: autosummary/accessor_method.rst :toctree: _generated + DataArray.raster.clip DataArray.raster.clip_bbox DataArray.raster.clip_mask DataArray.raster.clip_geom diff --git a/hydromt/gis_utils.py b/hydromt/gis_utils.py index 1ed3e5ac5..9800737a7 100644 --- a/hydromt/gis_utils.py +++ b/hydromt/gis_utils.py @@ -365,7 +365,7 @@ def affine_to_coords(transform, shape, x_dim="x", y_dim="y"): if not isinstance(transform, Affine): transform = Affine(*transform) height, width = shape - if transform.b == 0: + if transform.b == transform.d == 0: x_coords, _ = transform * (np.arange(width) + 0.5, np.zeros(width) + 0.5) _, y_coords = transform * (np.zeros(height) + 0.5, np.arange(height) + 0.5) coords = { diff --git a/hydromt/raster.py b/hydromt/raster.py index 11bff366a..180120b73 100644 --- a/hydromt/raster.py +++ b/hydromt/raster.py @@ -82,6 +82,7 @@ def full_like( dims=other.dims, ) da.raster.set_attrs(**other.raster.attrs) + da.raster._transform = other.raster.transform return da @@ -185,6 +186,8 @@ def full_from_transform( attrs = attrs or {} if len(shape) not in [2, 3]: raise ValueError("Only 2D and 3D data arrays supported.") + if not isinstance(transform, Affine): + transform = Affine(*transform) coords = gis_utils.affine_to_coords(transform, shape[-2:], x_dim="x", y_dim="y") dims = ("y", "x") if len(shape) == 3: @@ -201,6 +204,7 @@ def full_from_transform( shape=shape, dims=dims, ) + da.raster._transform = transform return da @@ -223,6 +227,7 @@ class XGeoBase(object): def __init__(self, xarray_obj: xr.DataArray | xr.Dataset) -> None: """Initialize new object based on the xarray object provided.""" self._obj = xarray_obj + self._crs = None # create new coordinate with attributes in which to save x_dim, y_dim and crs. # other spatial properties are always calculated on the fly to ensure # consistency with data @@ -262,13 +267,13 @@ def crs(self) -> CRS: """Return horizontal Coordinate Reference System.""" # return horizontal crs by default to avoid errors downstream # with reproject / rasterize etc. - if "crs_wkt" not in self.attrs: - self.set_crs() - if "crs_wkt" in self.attrs: - crs = pyproj.CRS.from_user_input(self.attrs["crs_wkt"]) - return crs + if self._crs is not None: + crs = self._crs + else: + crs = self.set_crs() + return crs - def set_crs(self, input_crs=None): + def set_crs(self, input_crs=None, write_crs=True) -> CRS: """Set the Coordinate Reference System. Arguments @@ -276,39 +281,52 @@ def set_crs(self, input_crs=None): input_crs: int, dict, or str, optional Coordinate Reference System. Accepts EPSG codes (int or str) and proj (str or dict) + write_crs: bool, optional + If True (default), write CRS to attributes. """ crs_names = ["crs_wkt", "crs", "epsg"] names = list(self._obj.coords.keys()) if isinstance(self._obj, xr.Dataset): names = names + list(self._obj.data_vars.keys()) # user defined - if input_crs is not None: + if isinstance(input_crs, (int, str, dict)): input_crs = pyproj.CRS.from_user_input(input_crs) # look in grid_mapping and data variable attributes + elif input_crs is not None and not isinstance(input_crs, CRS): + raise ValueError(f"Invalid CRS type: {type(input_crs)}") else: + crs = None for name in crs_names: - # check default > GEO_MAP_COORDS attrs - crs = self._obj.coords[GEO_MAP_COORD].attrs.get(name, None) - if crs is None: # global attrs - crs = self._obj.attrs.pop(name, None) - for var in names: # data var and coords attrs - if name in self._obj[var].attrs: - crs = self._obj[var].attrs.pop(name) - break - if crs is not None: - # avoid Warning 1: +init=epsg:XXXX syntax is deprecated - crs = crs.removeprefix("+init=") if isinstance(crs, str) else crs - try: - input_crs = pyproj.CRS.from_user_input(crs) - break - except RuntimeError: - pass + # check default > GEO_MAP_COORDS attrs, then global attrs + if name in self.attrs: + crs = self.attrs.get(name) + break + if name in self._obj.attrs: + crs = self._obj.attrs.pop(name) + break + if crs is None: # check data var and coords attrs + for var in names: + for name in crs_names: + if name in self._obj[var].attrs: + crs = self._obj[var].attrs.pop(name) + break + if crs is not None: + # avoid Warning 1: +init=epsg:XXXX syntax is deprecated + if isinstance(crs, str): + crs = crs.removeprefix("+init=") + try: + input_crs = pyproj.CRS.from_user_input(crs) + except RuntimeError: + pass if input_crs is not None: - grid_map_attrs = input_crs.to_cf() - crs_wkt = input_crs.to_wkt() - grid_map_attrs["spatial_ref"] = crs_wkt - grid_map_attrs["crs_wkt"] = crs_wkt - self.set_attrs(**grid_map_attrs) + if write_crs: + grid_map_attrs = input_crs.to_cf() + crs_wkt = input_crs.to_wkt() + grid_map_attrs["spatial_ref"] = crs_wkt + grid_map_attrs["crs_wkt"] = crs_wkt + self.set_attrs(**grid_map_attrs) + self._crs = input_crs + return input_crs class XRasterBase(XGeoBase): @@ -318,6 +336,10 @@ class XRasterBase(XGeoBase): def __init__(self, xarray_obj): """Initialize new object based on the xarray object provided.""" super(XRasterBase, self).__init__(xarray_obj) + self._res = None + self._rotation = None + self._origin = None + self._transform = None @property def x_dim(self) -> str: @@ -472,11 +494,14 @@ def height(self) -> int: @property def transform(self) -> Affine: """Return the affine transform of the object.""" + if self._transform is not None: + return self._transform transform = ( Affine.translation(*self.origin) * Affine.rotation(self.rotation) * Affine.scale(*self.res) ) + self._transform = transform return transform @property @@ -520,21 +545,35 @@ def res(self) -> tuple[float, float]: NOTE: rotated rasters with a negative dx are not supported. """ - xs, ys = self.xcoords.data, self.ycoords.data - dx, dy = 0, 0 - if xs.ndim == 1: - dx = xs[1] - xs[0] - dy = ys[1] - ys[0] - elif xs.ndim == 2: - ddx0 = xs[1, 0] - xs[0, 0] - ddy0 = ys[1, 0] - ys[0, 0] - ddx1 = xs[0, 1] - xs[0, 0] - ddy1 = ys[0, 1] - ys[0, 0] - dx = math.hypot(ddx1, ddy1) # always positive! - dy = math.hypot(ddx0, ddy0) + # cached transform is leading, then cached res, then coords + if self._transform is not None: + if self._transform.b == self._transform.d == 0: + dx, dy = self._transform.a, self._transform.e + else: + xy0 = self._transform * (0, 0) + x1y = self._transform * (1, 0) + xy1 = self._transform * (0, 1) + ddx0, ddy0 = xy1[0] - xy0[0], xy1[1] - xy0[1] + dx = math.hypot(x1y[0] - xy0[0], x1y[1] - xy0[1]) + dy = math.hypot(xy1[0] - xy0[0], xy1[1] - xy0[1]) + elif self._res is not None: + return self._res + else: + xs, ys = self.xcoords.data, self.ycoords.data + if xs.ndim == 1: + dx = xs[1] - xs[0] + dy = ys[1] - ys[0] + elif xs.ndim == 2: + ddx0 = xs[1, 0] - xs[0, 0] + ddy0 = ys[1, 0] - ys[0, 0] + ddx1 = xs[0, 1] - xs[0, 0] + ddy1 = ys[0, 1] - ys[0, 0] + dx = math.hypot(ddx1, ddy1) # always positive! + dy = math.hypot(ddx0, ddy0) + if self.rotation != 0: + # find grid top-down orientation rot = self.rotation acos = math.cos(math.radians(rot)) - # find grid top-down orientation if ( (acos < 0 and ddy0 > 0) or (acos > 0 and ddy0 < 0) @@ -545,6 +584,7 @@ def res(self) -> tuple[float, float]: ) ): dy = -1 * dy + self._res = dx, dy return dx, dy @property @@ -553,37 +593,56 @@ def rotation(self) -> float: NOTE: rotated rasters with a negative dx are not supported. """ - xs, ys = self.xcoords.data, self.ycoords.data - rot = 0 - if xs.ndim == 2: - ddx1 = xs[0, -1] - xs[0, 0] - ddy1 = ys[0, -1] - ys[0, 0] - if not np.isclose(ddx1, 0): - rot = math.degrees(math.atan(ddy1 / ddx1)) - else: - rot = -90 - if ddx1 < 0: - rot = 180 + rot - elif ddy1 < 0: - rot = 360 + rot + # cached transform is leading, then cached rotation, then coords + rot = None + if self._transform is not None: + if self._transform.b == self._transform.d == 0: + rot = 0 + elif self._transform.determinant >= 0: + rot = self._transform.rotation_angle + if rot is None: + if self._rotation is not None: + return self._rotation + elif self.xcoords.ndim == 1: + rot = 0 + elif self.xcoords.ndim == 2: + xs, ys = self.xcoords.data, self.ycoords.data + ddx1 = xs[0, -1] - xs[0, 0] + ddy1 = ys[0, -1] - ys[0, 0] + if not np.isclose(ddx1, 0): + rot = math.degrees(math.atan(ddy1 / ddx1)) + else: + rot = -90 + if ddx1 < 0: + rot = 180 + rot + elif ddy1 < 0: + rot = 360 + rot + self._rotation = rot return rot @property def origin(self) -> tuple[float, float]: """Return origin of grid (x0, y0) tuple.""" - xs, ys = self.xcoords.data, self.ycoords.data + # cached transform is leading, then cached origin, then coords x0, y0 = 0, 0 - dx, dy = self.res - if xs.ndim == 1: - x0, y0 = xs[0] - dx / 2, ys[0] - dy / 2 - elif xs.ndim == 2: - alpha = math.radians(self.rotation) - beta = math.atan(dx / dy) - c = math.hypot(dx, dy) / 2.0 - a = c * math.sin(beta - alpha) - b = c * math.cos(beta - alpha) - x0 = xs[0, 0] - np.sign(dy) * a - y0 = ys[0, 0] - np.sign(dy) * b + if self._transform is not None: + x0, y0 = self._transform * (0, 0) + elif self._origin is not None: + return self._origin + else: + xs, ys = self.xcoords.data, self.ycoords.data + dx, dy = self.res + if xs.ndim == 1: + x0, y0 = xs[0] - dx / 2, ys[0] - dy / 2 + elif xs.ndim == 2: + alpha = math.radians(self.rotation) + beta = math.atan(dx / dy) + c = math.hypot(dx, dy) / 2.0 + a = c * math.sin(beta - alpha) + b = c * math.cos(beta - alpha) + x0 = xs[0, 0] - np.sign(dy) * a + y0 = ys[0, 0] - np.sign(dy) * b + self._origin = x0, y0 return x0, y0 def _check_dimensions(self) -> None: @@ -1074,6 +1133,25 @@ def reclass_exact(x, ddict): ds_out[param] = da_param return ds_out + def clip(self, xslice: slice, yslice: slice): + """Clip object based on slices. + + Arguments + --------- + xslice, yslice : slice + x and y slices + + Returns + ------- + xarray.DataSet or DataArray + Data clipped to slices + """ + obj = self._obj.isel({self.x_dim: xslice, self.y_dim: yslice}) + obj.raster._crs = self._crs + translation = Affine.translation(xslice.start, yslice.start) + obj.raster._transform = self._transform * translation + return obj + def clip_bbox(self, bbox, align=None, buffer=0, crs=None): """Clip object based on a bounding box. @@ -1107,32 +1185,20 @@ def clip_bbox(self, bbox, align=None, buffer=0, crs=None): s = (s // align) * align e = (e // align + 1) * align n = (n // align + 1) * align - if self.rotation > 1: # update bbox based on clip to rotated box + xs, ys = [w, e], [s, n] + if self.rotation > 0: # update bbox based on clip to rotated box gdf_bbox = gpd.GeoDataFrame(geometry=[box(w, s, e, n)], crs=self.crs).clip( self.box ) - xs, ys = [w, e], [s, n] if not np.all(gdf_bbox.is_empty): xs, ys = zip(*gdf_bbox.dissolve().boundary[0].coords[:]) - cs, rs = ~self.transform * (np.array(xs), np.array(ys)) - c0 = max(round(int(cs.min() - buffer)), 0) - r0 = max(round(int(rs.min() - buffer)), 0) - c1 = int(round(cs.max() + buffer)) - r1 = int(round(rs.max() + buffer)) - return self._obj.isel( - {self.x_dim: slice(c0, c1), self.y_dim: slice(r0, r1)} - ) - else: - # TODO remove this part could also be based on row col just like the rotated - xres, yres = self.res - y0, y1 = (n, s) if yres < 0 else (s, n) - x0, x1 = (e, w) if xres < 0 else (w, e) - if buffer > 0: - y0 -= yres * buffer - y1 += yres * buffer - x0 -= xres * buffer - x1 += xres * buffer - return self._obj.sel({self.x_dim: slice(x0, x1), self.y_dim: slice(y0, y1)}) + cs, rs = ~self.transform * (np.array(xs), np.array(ys)) + # use round to get integer slices + c0 = max(int(round(cs.min() - buffer)), 0) + r0 = max(int(round(rs.min() - buffer)), 0) + c1 = int(round(cs.max() + buffer)) + r1 = int(round(rs.max() + buffer)) + return self.clip(slice(c0, c1), slice(r0, r1)) def clip_mask(self, da_mask: xr.DataArray, mask: bool = False): """Clip object to region with mask values greater than zero. @@ -1159,12 +1225,12 @@ def clip_mask(self, da_mask: xr.DataArray, mask: bool = False): raise ValueError("No valid values found in mask.") # clip row_slice, col_slice = ndimage.find_objects(da_mask.values.astype(np.uint8))[0] - obj_clip = self._obj.isel({self.x_dim: col_slice, self.y_dim: row_slice}) + obj = self.clip(xslice=col_slice, yslice=row_slice) if mask: # mask values and add mask coordinate mask_bin = da_mask.isel({self.x_dim: col_slice, self.y_dim: row_slice}) - obj_clip.coords["mask"] = xr.Variable(self.dims, mask_bin.values) - obj_clip = obj_clip.raster.mask(obj_clip.coords["mask"]) - return obj_clip + obj.coords["mask"] = xr.Variable(self.dims, mask_bin.values) + obj = obj.raster.mask(obj.coords["mask"]) + return obj def clip_geom(self, geom, align=None, buffer=0, mask=False): """Clip object to bounding box of the geometry. @@ -1274,6 +1340,8 @@ def rasterize( ) da_out.raster.set_nodata(nodata) da_out.raster.set_attrs(**self.attrs) + da_out.raster.set_crs(self.crs) + da_out.raster._transform = self._transform return da_out def rasterize_geometry( @@ -1373,8 +1441,9 @@ def rasterize_geometry( da_out = da_out.fillna(0) da_out.name = "fraction" - da_out.raster.set_crs(ds_like.raster.crs) da_out.raster.set_nodata(nodata) + da_out.raster.set_crs(self.crs) + da_out.raster._transform = self._transform # Rename da_area if name is not None: da_out.name = name @@ -1491,6 +1560,7 @@ def area_grid(self, dtype=np.float32): ) da_area.raster.set_nodata(0) da_area.raster.set_crs(self.crs) + da_area.raster._transform = self._transform da_area.attrs.update(unit="m2") return da_area.rename("area") @@ -1517,6 +1587,8 @@ def density_grid(self): unit = self._obj.attrs.get("unit", "") ds_out = self._obj / area ds_out.attrs.update(unit=f"{unit}.m-2") + ds_out.raster._crs = self._crs + ds_out.raster._transform = self._transform return ds_out def _dst_transform( @@ -1554,11 +1626,15 @@ def _dst_crs(self, dst_crs=None): # check CRS and transform set destination crs if missing if self.crs is None: raise ValueError("CRS is missing. Use set_crs function to resolve.") - if dst_crs == "utm": + if isinstance(dst_crs, pyproj.CRS): + return dst_crs + elif dst_crs == "utm": # make sure bounds are in EPSG:4326 - dst_crs = gis_utils.utm_crs(self.box.to_crs(4326).total_bounds) + dst_crs = gis_utils.utm_crs(self.transform_bounds(4326)) + elif dst_crs is not None: + dst_crs = CRS.from_user_input(dst_crs) else: - dst_crs = CRS.from_user_input(dst_crs) if dst_crs is not None else self.crs + dst_crs = self.crs return dst_crs def nearest_index( @@ -1660,6 +1736,7 @@ def nearest_index( ) index.raster.set_crs(dst_crs) index.raster.set_nodata(-1) + index.raster._tranform = dst_transform return index @@ -1716,6 +1793,7 @@ def from_numpy(data, transform, nodata=None, attrs=None, crs=None): da.attrs.update(attrs) if crs is not None: da.raster.set_crs(input_crs=crs) + da.raster._transform = transform return da @property @@ -1967,6 +2045,7 @@ def _reproj(da, **kwargs): da_temp = da_temp.chunk(chunks) da_reproj = _da.map_blocks(_reproj, kwargs=reproj_kwargs, template=da_temp) da_reproj.raster.set_crs(dst_crs) + da_reproj.raster._transform = dst_transform return da_reproj.raster.reset_spatial_dims_attrs() def reproject_like(self, other, method="nearest"): @@ -2164,6 +2243,7 @@ def interpolate_na( ) interp_array.raster.set_nodata(self.nodata) interp_array.raster.set_crs(self.crs) + interp_array.raster._transform = self.transform return interp_array def to_xyz_tiles( @@ -2223,12 +2303,13 @@ def to_xyz_tiles( # create temp tile os.makedirs(ssd, exist_ok=True) - temp = obj[u : u + h, l : l + w] + temp = obj[u : u + h, l : l + w].load() if zl != 0: temp = temp.coarsen( {x_dim: 2**diff, y_dim: 2**diff}, boundary="pad" ).mean() temp.raster.set_nodata(nodata) + temp.raster._crs = obj.raster.crs if driver == "netcdf4": path = join(ssd, f"{row}.nc") @@ -2267,7 +2348,7 @@ def to_xyz_tiles( def to_slippy_tiles( self, root: Path | str, - reproj_method: str = "average", + reproj_method: str = "bilinear", min_lvl: int = None, max_lvl: int = None, driver="png", @@ -2284,9 +2365,9 @@ def to_slippy_tiles( root : Path | str Path where the database will be saved reproj_method : str, optional - How to resample the data when downscaling. - E.g. 'nearest' for resampling with the nearest value - (This is only used for the first/ highest zoomlevel) + How to resample the data at the finest zoom level, by default 'bilinear'. + See :py:meth:`~hydromt.raster.RasterDataArray.reproject` for existing + methods. min_lvl, max_lvl : int, optional The minimum and maximum zoomlevel to be produced. If None, the zoomlevels will be determined based on the data resolution @@ -2298,46 +2379,40 @@ def to_slippy_tiles( norm : object, optional A matplotlib Normalize object that defines a range between a maximum and minimum value - **kwargs Key-word arguments to write file for netcdf4, these are passed to ~:py:meth:xarray.DataArray.to_netcdf: for GTiff, these are passed to ~:py:meth:hydromt.RasterDataArray.to_raster: for png, these are passed to ~:py:meth:PIL.Image.Image.save: """ - # for now these are optional dependencies - if driver.lower() == "png": + # check driver + ldriver = driver.lower() + ext = {"png": "png", "netcdf4": "nc", "gtiff": "tif"}.get(ldriver, None) + if ext is None: + raise ValueError(f"Unkown file driver {driver}, use png, netcdf4 or GTiff") + if ldriver == "png": try: + # optional imports import matplotlib.pyplot as plt from PIL import Image except ImportError: raise ImportError("matplotlib and pillow are required for png output") - # Fixed pixel size and CRS for XYZ tiles - pxs = 256 - # Extent in y-direction for pseudo mercator (EPSG:3857) - y_ext = math.atan(math.sinh(math.pi)) * (180 / math.pi) - y_ext_pm = mct.xy(0, y_ext)[1] - - crs = CRS.from_epsg(3857) - ext = {"png": "png", "netcdf4": "nc", "gtiff": "tif"}.get(driver.lower(), None) - if ext is None: - raise ValueError(f"Unkown file driver {driver}, use png, netcdf4 or GTiff") - - # Object to local variable, also transpose it and extract some meta + # check data and make sure the y-axis as first dimension if self._obj.ndim != 2: raise ValueError("Only 2d DataArrays are accepted.") - # make sure the y-axis as first dimension and nodata values set to nan - obj = self.mask_nodata().transpose(self.y_dim, self.x_dim) - # make sure dataarray has a name + obj = self._obj.transpose(self.y_dim, self.x_dim) name = obj.name or "data" obj.name = name + # get some properties of the data obj_res = obj.raster.res[0] obj_bounds = list(obj.raster.bounds) + nodata = self.nodata + dtype = obj.dtype # colormap output - if cmap is not None and driver != "png": + if cmap is not None and ldriver != "png": raise ValueError("Colormap is only supported for png output") if isinstance(cmap, str): cmap = plt.get_cmap(cmap) @@ -2346,11 +2421,19 @@ def to_slippy_tiles( vmin=obj.min().load().item(), vmax=obj.max().load().item() ) - # for now we assume output has float32 dtype + # some tile size, bounds and CRS properties + # Fixed pixel size and CRS for XYZ tiles + pxs = 256 + crs = CRS.from_epsg(3857) + # Extent in y-direction for pseudo mercator (EPSG:3857) + y_ext = math.atan(math.sinh(math.pi)) * (180 / math.pi) + y_ext_pm = mct.xy(0, y_ext)[1] + + # default kwargs for writing files kwargs0 = { - "netcdf4": {"encoding": {name: {"dtype": "float32", "zlib": True}}}, - "gtiff": {"driver": "GTiff", "compress": "deflate", "dtype": "float32"}, - }.get(driver.lower(), {}) + "netcdf4": {"encoding": {name: {"zlib": True}}}, + "gtiff": {"driver": "GTiff", "compress": "deflate"}, + }.get(ldriver, {}) kwargs = {**kwargs0, **kwargs} # Setting up information for determination of tile windows @@ -2370,34 +2453,35 @@ def to_slippy_tiles( # Calculate min/max zoomlevel based if max_lvl is None: # calculate max zoomlevel close to native resolution # Determine the max number of zoom levels with the resolution - # This section is purely for the resolution - obj_clipped_to_pseudo = obj.raster.clip_bbox( - (-180, -y_ext, 180, y_ext), - crs="EPSG:4326", - ) - tr_3857 = rasterio.warp.calculate_default_transform( - obj_clipped_to_pseudo.raster.crs, - "EPSG:3857", - *obj_clipped_to_pseudo.shape, - *obj_clipped_to_pseudo.raster.bounds, - )[0] - - del obj_clipped_to_pseudo + # using the resolution determined by default transform + if self.crs != crs: + obj_clipped_to_pseudo = obj.raster.clip_bbox( + (-180, -y_ext, 180, y_ext), + crs="EPSG:4326", + ) + tr_3857 = rasterio.warp.calculate_default_transform( + obj_clipped_to_pseudo.raster.crs, + "EPSG:3857", + *obj_clipped_to_pseudo.shape, + *obj_clipped_to_pseudo.raster.bounds, + )[0] + dres = tr_3857[0] + del obj_clipped_to_pseudo + else: + dres = obj_res[0] # Calculate the maximum zoom level - dres = tr_3857[0] max_lvl = int( math.ceil((math.log10((y_ext_pm * 2) / (dres * pxs)) / math.log10(2))) ) if min_lvl is None: # calculate min zoomlevel based on the data extent - min_lvl = mct.bounding_tile(*bounds_4326).z + min_lvl = mct.bounding_tile(*bounds_4326, truncate=True).z # Loop through the zoom levels zoom_levels = {} logger.info(f"Producing tiles from zoomlevel {min_lvl} to {max_lvl}") for zl in range(max_lvl, min_lvl - 1, -1): fns = [] - # Go through the zoomlevels for i, tile in enumerate(mct.tiles(*bounds_4326, zl, truncate=True)): ssd = Path(root, str(zl), f"{tile.x}") os.makedirs(ssd, exist_ok=True) @@ -2405,15 +2489,27 @@ def to_slippy_tiles( if i == 0: # zoom level : resolution in meters zoom_levels[zl] = abs(tile_bounds[2] - tile_bounds[0]) / pxs if zl == max_lvl: - # For the first zoomlevel, we can just clip the data - # does this need a try/except? + # For the first zoomlevel, first clip the data src_tile = obj.raster.clip_bbox( - tile_bounds, crs=crs, buffer=2, align=True + tile_bounds, crs=crs, buffer=1, align=True + ).load() + # then reproject the data to the tile + dst_transform = rasterio.transform.from_bounds( + *tile_bounds, pxs, pxs ) + dst_tile = src_tile.raster.reproject( + dst_crs=crs, + dst_transform=dst_transform, + dst_width=pxs, + dst_height=pxs, + method=reproj_method, + ) + dst_tile.name = name + dst_data = dst_tile.values else: # Every tile from this level has 4 child tiles on the previous lvl # Create a temporary array, 4 times the size of a tile - temp = np.full((pxs * 2, pxs * 2), np.nan, dtype=np.float64) + temp = np.full((pxs * 2, pxs * 2), nodata, dtype=dtype) for ic, child in enumerate(mct.children(tile)): fn = Path(root, str(child.z), str(child.x), f"{child.y}.{ext}") # Check if the file is really there, if not: it was not written @@ -2422,64 +2518,58 @@ def to_slippy_tiles( # order: top-left, top-right, bottom-right, bottom-left yslice = slice(0, pxs) if ic in [0, 1] else slice(pxs, None) xslice = slice(0, pxs) if ic in [0, 3] else slice(pxs, None) - if driver == "netcdf4": - with xr.open_dataset(fn) as ds: + if ldriver == "netcdf4": + with xr.open_dataset(fn, mask_and_scale=False) as ds: temp[yslice, xslice] = ds[name].values - elif driver == "GTiff": - with rioxarray.open_rasterio( - fn, parse_coordinates=False - ) as da: - temp[yslice, xslice] = da.squeeze(drop=True).values - elif driver == "png": + elif ldriver == "gtiff": + with rasterio.open(fn) as src: + temp[yslice, xslice] = src.read(1) + elif ldriver == "png": if cmap is not None: fn_bin = str(fn).replace(f".{ext}", ".bin") with open(fn_bin, "r") as f: - data = np.fromfile(f, "f4").reshape((pxs, pxs)) + data = np.fromfile(f, dtype).reshape((pxs, pxs)) os.remove(fn_bin) # clean up else: - data = rgba2elevation(np.array(Image.open(fn))) + im = np.array(Image.open(fn)) + data = rgba2elevation(im, nodata=nodata, dtype=dtype) temp[yslice, xslice] = data - # create a dataarray from the temporary array - src_transform = rasterio.transform.from_bounds( - *tile_bounds, pxs * 2, pxs * 2 - ) - src_tile = RasterDataArray.from_numpy( - temp, src_transform, crs=crs, nodata=np.nan - ) - src_tile.name = name - - # reproject the data to the tile / coares resolution - dst_transform = rasterio.transform.from_bounds(*tile_bounds, pxs, pxs) - dst_tile = src_tile.raster.reproject( - dst_crs=crs, - dst_transform=dst_transform, - dst_width=pxs, - dst_height=pxs, - method=reproj_method, - ) + # coarsen the data using mean + temp = temp.reshape((pxs, 2, pxs, 2)) + dst_data = np.ma.masked_values(temp, nodata).mean(axis=(1, 3)).data + if ldriver != "png": + # create a dataarray from the temporary array + dst_transform = rasterio.transform.from_bounds( + *tile_bounds, pxs, pxs + ) + dst_tile = RasterDataArray.from_numpy( + dst_data, dst_transform, crs=crs, nodata=nodata + ) + dst_tile.name = name + # write the data to file fn_out = Path(ssd, f"{tile.y}.{ext}") fns.append(fn_out) - if driver.lower() == "png": + if ldriver == "png": if cmap is not None and zl != min_lvl: # create temp bin file with data for upsampling fn_bin = str(fn_out).replace(f".{ext}", ".bin") - dst_tile.values.astype("f4").tofile(fn_bin) - rgba = cmap(norm(dst_tile.values), bytes=True) + dst_data.astype(dtype).tofile(fn_bin) + rgba = cmap(norm(dst_data), bytes=True) else: # Create RGBA bands from the data and save it as png - rgba = elevation2rgba(dst_tile.values) + rgba = elevation2rgba(dst_data, nodata=nodata) Image.fromarray(rgba).save(fn_out, **kwargs) - elif driver.lower() == "netcdf4": + elif ldriver == "netcdf4": # write the data to netcdf dst_tile = dst_tile.raster.gdal_compliant() dst_tile.to_netcdf(fn_out, **kwargs) - elif driver.lower() == "gtiff": + elif ldriver == "gtiff": # write the data to geotiff dst_tile.raster.to_raster(fn_out, **kwargs) # Write files to txt and create a vrt using GDAL - if driver.lower() != "png": + if ldriver != "png": txt_fn = Path(root, str(zl), "filelist.txt") vrt_fn = Path(root, f"lvl{zl}.vrt") with open(txt_fn, "w") as f: @@ -2488,7 +2578,7 @@ def to_slippy_tiles( gis_utils.create_vrt(vrt_fn, file_list_path=txt_fn) # Write a quick yaml for the database - if driver.lower() != "png": + if ldriver != "png": yml = { "crs": 3857, "data_type": "RasterDataset", diff --git a/hydromt/utils.py b/hydromt/utils.py index 06739d5ae..adcb8686e 100644 --- a/hydromt/utils.py +++ b/hydromt/utils.py @@ -43,21 +43,22 @@ def partition_dictionaries(left, right): return common, left_less_right, right_less_left -def elevation2rgba(val): +def elevation2rgba(val, nodata=np.nan): """Convert elevation to rgb tuple.""" val += 32768 r = np.floor(val / 256).astype(np.uint8) g = np.floor(val % 256).astype(np.uint8) b = np.floor((val - np.floor(val)) * 256).astype(np.uint8) - a = np.where(np.isnan(val), 0, 255).astype(np.uint8) + mask = np.isnan(val) if np.isnan(nodata) else val == nodata + a = np.where(mask, 0, 255).astype(np.uint8) return np.stack((r, g, b, a), axis=2) -def rgba2elevation(rgba: np.ndarray): +def rgba2elevation(rgba: np.ndarray, nodata=np.nan, dtype=np.float32): """Convert rgb tuple to elevation.""" r, g, b, a = np.split(rgba, 4, axis=2) val = (r * 256 + g + b / 256) - 32768 - return np.where(a == 0, np.nan, val).squeeze() + return np.where(a == 0, nodata, val).squeeze().astype(dtype) def _dict_pprint(d): diff --git a/tests/test_raster.py b/tests/test_raster.py index a4ed0b70b..36059abb3 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -53,6 +53,9 @@ def test_raster_properties(origin, rotation, res, shape, bounds): assert np.allclose(da.raster.box.total_bounds, da.raster.bounds) assert np.allclose(bounds, da.raster.internal_bounds) assert da.raster.box.crs == da.raster.crs + # attributes do not persist after slicing + assert da[:2, :2].raster._transform is None + assert da[:2, :2].raster._crs is None @pytest.mark.parametrize(("transform", "shape"), testdata) @@ -265,6 +268,14 @@ def test_clip(transform, shape): # create rasterdataarray with crs da = raster.full_from_transform(transform, shape, nodata=1, name="test", crs=4326) da.raster.set_nodata(0) + # attributes do not persist with xarray slicing + da1 = da[:2, :2] + assert da1.raster._transform is None + assert da1.raster._crs is None + # attributes do persisit when slicing using clip_bbox + raster1d = da.raster.clip(slice(1, 2), slice(0, 2)) + assert raster1d.raster._crs is not None + assert raster1d.raster.transform is not None # create gdf covering approx half raster w, s, _, n = da.raster.bounds e, _ = da.raster.transform * (shape[1] // 2, shape[0] // 2) @@ -295,11 +306,15 @@ def test_clip(transform, shape): if da.raster.rotation != 0: return assert np.all(np.isclose(da_clip0.raster.bounds, gdf.total_bounds)) - # test bbox - align - align = np.round(abs(da.raster.res[0] * 2), 2) - da_clip = da.raster.clip_bbox(gdf.total_bounds, align=align) - dalign = np.round(da_clip.raster.bounds[2], 2) % align - assert np.isclose(dalign, 0) or np.isclose(dalign, align) + + +def test_clip_align(rioda): + # test align + bbox = (3.5, -10.5, 5.5, -9.5) + da_clip = rioda.raster.clip_bbox(bbox) + assert np.all(np.isclose(da_clip.raster.bounds, bbox)) + da_clip = rioda.raster.clip_bbox(bbox, align=1) + assert da_clip.raster.bounds == (3, -11, 6, -9) def test_clip_errors(rioda): @@ -552,7 +567,7 @@ def test_to_slippy_tiles(tmpdir, rioda_large): fn = join(png_dir, "7", "64", "64.png") im = np.array(Image.open(fn)) assert im.shape == (256, 256, 4) - assert all(im[0, 0, :] == [128, 0, 131, 255]) + assert all(im[0, 0, :] == [128, 0, 132, 255]) # test with cmap png_dir = join(tmpdir, "tiles_png_cmap") @@ -560,7 +575,7 @@ def test_to_slippy_tiles(tmpdir, rioda_large): fn = join(png_dir, "7", "64", "64.png") im = np.array(Image.open(fn)) assert im.shape == (256, 256, 4) - assert all(im[0, 0, :] == [31, 148, 139, 255]) + assert all(im[0, 0, :] == [32, 143, 140, 255]) # gtiff tif_dir = join(tmpdir, "tiles_tif")