Rasterio fix (#63)
* add hydrofabric version update checking

* speed up s3fs using parallel download of large files

* remove rioxarray

* improve cli response time by importing modules asynchronously (see the sketch after this list)

* improve first launch print statements

* update cli progress bars

* add elapsed and remaining time to forcing progress bar
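The CLI diff itself is not rendered on this page, so the snippet below is only an illustrative sketch (placeholder module and function names, not the commit's code) of one way to move heavy imports off the startup path:

import threading

_heavy = {}

def _import_heavy_modules():
    # slow imports deferred to a background thread so argument parsing and --help stay fast
    import xarray
    import geopandas
    _heavy["xr"] = xarray
    _heavy["gpd"] = geopandas

_loader = threading.Thread(target=_import_heavy_modules, daemon=True)
_loader.start()

def main():
    # ... parse CLI arguments here ...
    _loader.join()  # block only once the heavy modules are actually needed
    print(_heavy["xr"].__version__)

if __name__ == "__main__":
    main()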
JoshCu authored Nov 26, 2024
1 parent 61f85af commit 5199725
Showing 8 changed files with 321 additions and 68 deletions.
2 changes: 2 additions & 0 deletions modules/data_processing/file_paths.py
@@ -8,6 +8,8 @@ class file_paths:
"""
config_file = Path("~/.ngiab/preprocessor").expanduser()
hydrofabric_dir = Path("~/.ngiab/hydrofabric/v2.2").expanduser()
hydrofabric_download_log = Path("~/.ngiab/hydrofabric/v2.2/download_log.json").expanduser()
no_update_hf = Path("~/.ngiab/hydrofabric/v2.2/no_update").expanduser()
cache_dir = Path("~/.ngiab/zarr_cache").expanduser()
output_dir = None
data_sources = Path(__file__).parent.parent / "data_sources"
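The two new paths support the "hydrofabric version update checking" bullet above. A hypothetical sketch of how they might be consulted (the real logic lives elsewhere in this commit and may differ):

import json
from data_processing.file_paths import file_paths

def hydrofabric_update_allowed() -> bool:
    # a "no_update" marker file lets users opt out of update checks entirely
    return not file_paths.no_update_hf.exists()

def last_download_info() -> dict:
    # download_log.json is assumed here to record what was fetched and when
    if file_paths.hydrofabric_download_log.exists():
        return json.loads(file_paths.hydrofabric_download_log.read_text())
    return {}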
57 changes: 49 additions & 8 deletions modules/data_processing/forcings.py
@@ -17,6 +17,15 @@
from data_processing.file_paths import file_paths
from data_processing.zarr_utils import get_forcing_data
from exactextract import exact_extract
from exactextract.raster import NumPyRasterSource
from rich.progress import (
Progress,
BarColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)


logger = logging.getLogger(__name__)
# Suppress the specific warning from numpy to keep the cli output clean
@@ -27,6 +36,7 @@
"ignore", message="'GeoDataFrame.swapaxes' is deprecated", category=FutureWarning
)


def weighted_sum_of_cells(flat_raster: np.ndarray, cell_ids: np.ndarray , factors: np.ndarray):
# Create an output array initialized with zeros
# dimensions are raster[time][x*y]
@@ -37,10 +47,17 @@ def weighted_sum_of_cells(flat_raster: np.ndarray, cell_ids: np.ndarray , factor
return result


def get_cell_weights(raster, gdf):
def get_cell_weights(raster, gdf, wkt):
# Get the cell weights for each divide
xmin = raster.x[0]
xmax = raster.x[-1]
ymin = raster.y[0]
ymax = raster.y[-1]
rastersource = NumPyRasterSource(
raster["RAINRATE"], srs_wkt=wkt, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax
)
output = exact_extract(
raster["RAINRATE"],
rastersource,
gdf,
["cell_id", "coverage"],
include_cols=["divide_id"],
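This hunk is the core of the rasterio fix: instead of handing exact_extract a rioxarray-backed raster, the RAINRATE array is wrapped in exactextract's NumPyRasterSource, which only needs the grid extent and the divides' CRS as WKT. A self-contained sketch of the same call pattern (toy grid, made-up polygon and divide_id; output="pandas" is assumed):

import numpy as np
import geopandas as gpd
from shapely.geometry import box
from exactextract import exact_extract
from exactextract.raster import NumPyRasterSource

# toy 10x10 raster spanning 0..10 in both directions
data = np.arange(100, dtype="float64").reshape(10, 10)

# one polygon tagged with a divide_id, the way hydrofabric divides are
gdf = gpd.GeoDataFrame(
    {"divide_id": ["cat-1"]}, geometry=[box(2.3, 2.3, 6.7, 6.7)], crs="EPSG:4326"
)

# georeference the bare array from its extent and the divides' CRS WKT,
# no rasterio/rioxarray involved
source = NumPyRasterSource(
    data, srs_wkt=gdf.crs.to_wkt(), xmin=0, xmax=10, ymin=0, ymax=10
)

# per polygon: the flat indices of the cells it touches and the fraction of
# each cell covered by the polygon
weights = exact_extract(
    source, gdf, ["cell_id", "coverage"], include_cols=["divide_id"], output="pandas"
)
print(weights)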
@@ -109,11 +126,11 @@ def process_chunk_shared(variable, times, shm_name, shape, dtype, chunk):

def get_cell_weights_parallel(gdf, input_forcings, num_partitions):
gdf_chunks = np.array_split(gdf, num_partitions)
wkt = gdf.crs.to_wkt()
one_timestep = input_forcings.isel(time=0).compute()
with multiprocessing.Pool() as pool:
args = [(one_timestep, gdf_chunk) for gdf_chunk in gdf_chunks]
args = [(one_timestep, gdf_chunk, wkt) for gdf_chunk in gdf_chunks]
catchments = pool.starmap(get_cell_weights, args)

return pd.concat(catchments)


@@ -139,20 +156,38 @@ def compute_zonal_stats(
"V2D": "VGRD_10maboveground",
}

results = []
cat_chunks = np.array_split(catchments, num_partitions)
forcing_times = merged_data.time.values

progress = Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TextColumn("{task.completed}/{task.total}"),
"•",
TextColumn(" Elapsed Time:"),
TimeElapsedColumn(),
TextColumn(" Remaining Time:"),
TimeRemainingColumn(),
)

timer = time.perf_counter()
variable_task = progress.add_task(
"[cyan]Processing variables...", total=len(variables), elapsed=0
)
progress.start()
for variable in variables.keys():
progress.update(variable_task, advance=1)
progress.update(variable_task, description=f"Processing {variable}")

if variable not in merged_data.data_vars:
logger.warning(f"Variable {variable} not in forcings, skipping")
continue

# to make sure this fits in memory, we need to chunk the data
time_chunks = get_index_chunks(merged_data[variable])

chunk_task = progress.add_task("[purple] processing chunks", total=len(time_chunks))
for i, times in enumerate(time_chunks):
progress.update(chunk_task, advance=1)
start, end = times
# select the chunk of time we want to process
data_chunk = merged_data[variable].isel(time=slice(start,end))
Expand Down Expand Up @@ -184,8 +219,14 @@ def compute_zonal_stats(
xr.concat(datasets, dim="time").to_netcdf(forcings_dir / f"{variable}.nc")
for file in forcings_dir.glob("temp/*.nc"):
file.unlink()
progress.remove_task(chunk_task)
progress.update(
variable_task,
description=f"Forcings processed in {time.perf_counter() - timer:2f} seconds",
)
progress.stop()
logger.info(
f"Forcing generation complete! Zonal stats computed in {time.time() - timer_start} seconds"
f"Forcing generation complete! Zonal stats computed in {time.time() - timer_start:2f} seconds"
)
write_outputs(forcings_dir, variables)

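For reference, the rich column stack added above in a minimal, runnable form (the loop is a stand-in for the per-variable work):

import time
from rich.progress import (
    Progress,
    BarColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)

progress = Progress(
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TextColumn("{task.completed}/{task.total}"),
    "•",
    TextColumn(" Elapsed Time:"),
    TimeElapsedColumn(),
    TextColumn(" Remaining Time:"),
    TimeRemainingColumn(),
)

with progress:  # same effect as progress.start() / progress.stop()
    task = progress.add_task("[cyan]Processing variables...", total=5)
    for _ in range(5):
        time.sleep(0.1)  # placeholder for real work
        progress.update(task, advance=1)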
5 changes: 4 additions & 1 deletion modules/data_processing/gpkg_utils.py
@@ -32,7 +32,7 @@ def verify_indices(gpkg: str = file_paths.conus_hydrofabric) -> None:
Verify that the indices in the specified geopackage are correct.
If they are not, create the correct indices.
"""
logger.info("Building database indices")
logger.debug("Building database indices")
new_indicies = [
'CREATE INDEX "diid" ON "divides" ( "divide_id" ASC );',
'CREATE INDEX "ditid" ON "divides" ( "toid" ASC );',
@@ -55,6 +55,9 @@ def verify_indices(gpkg: str = file_paths.conus_hydrofabric) -> None:
con = sqlite3.connect(gpkg)
indices = con.execute("SELECT name FROM sqlite_master WHERE type = 'index'").fetchall()
indices = [x[0] for x in indices]
missing = [x for x in new_indicies if x.split('"')[1] not in indices]
if len(missing) > 0:
logger.info("Creating indices")
for index in new_indicies:
if index.split('"')[1] not in indices:
logger.info(f"Creating index {index}")
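The new check only logs "Creating indices" when something is actually missing. The same check-then-create pattern in isolation, runnable against an in-memory database (table schema reduced to the two indexed columns):

import sqlite3

wanted = [
    'CREATE INDEX "diid" ON "divides" ( "divide_id" ASC );',
    'CREATE INDEX "ditid" ON "divides" ( "toid" ASC );',
]

con = sqlite3.connect(":memory:")
con.execute('CREATE TABLE "divides" ("divide_id" TEXT, "toid" TEXT)')

existing = {row[0] for row in con.execute("SELECT name FROM sqlite_master WHERE type = 'index'")}
# the index name is the first double-quoted token of each CREATE statement
missing = [stmt for stmt in wanted if stmt.split('"')[1] not in existing]
for stmt in missing:
    con.execute(stmt)
con.commit()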
77 changes: 77 additions & 0 deletions modules/data_processing/s3fs_utils.py
@@ -0,0 +1,77 @@
from s3fs import S3FileSystem
from s3fs.core import _error_wrapper, version_id_kw
from typing import Optional
import asyncio


class S3ParallelFileSystem(S3FileSystem):
"""S3FileSystem subclass that supports parallel downloads"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def _cat_file(
self,
path: str,
version_id: Optional[str] = None,
start: Optional[int] = None,
end: Optional[int] = None,
) -> bytes:
bucket, key, vers = self.split_path(path)
version_kw = version_id_kw(version_id or vers)

# If start/end specified, use single range request
if start is not None or end is not None:
head = {"Range": await self._process_limits(path, start, end)}
return await self._download_chunk(bucket, key, head, version_kw)

# For large files, use parallel downloads
try:
obj_size = (
await self._call_s3(
"head_object", Bucket=bucket, Key=key, **version_kw, **self.req_kw
)
)["ContentLength"]
except Exception as e:
# Fall back to single request if HEAD fails
return await self._download_chunk(bucket, key, {}, version_kw)

CHUNK_SIZE = 1 * 1024 * 1024 # 1MB chunks
if obj_size <= CHUNK_SIZE:
return await self._download_chunk(bucket, key, {}, version_kw)

# Calculate chunks for parallel download
chunks = []
for start in range(0, obj_size, CHUNK_SIZE):
end = min(start + CHUNK_SIZE - 1, obj_size - 1)
range_header = f"bytes={start}-{end}"
chunks.append({"Range": range_header})

# Download chunks in parallel
async def download_all_chunks():
tasks = [
self._download_chunk(bucket, key, chunk_head, version_kw) for chunk_head in chunks
]
chunks_data = await asyncio.gather(*tasks)
return b"".join(chunks_data)

return await _error_wrapper(download_all_chunks, retries=self.retries)

async def _download_chunk(self, bucket: str, key: str, head: dict, version_kw: dict) -> bytes:
"""Helper function to download a single chunk"""

async def _call_and_read():
resp = await self._call_s3(
"get_object",
Bucket=bucket,
Key=key,
**version_kw,
**head,
**self.req_kw,
)
try:
return await resp["Body"].read()
finally:
resp["Body"].close()

return await _error_wrapper(_call_and_read, retries=self.retries)
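Because _cat_file is the hook fsspec uses for whole-object reads, the parallel ranged-GET path is reached through the ordinary synchronous API as well. A usage sketch (the bucket and key are placeholders, not real objects):

from data_processing.s3fs_utils import S3ParallelFileSystem

fs = S3ParallelFileSystem(anon=True)  # anonymous access to a public bucket
data = fs.cat_file("s3://example-public-bucket/some/large-object.nc")  # hypothetical path
print(f"downloaded {len(data)} bytes")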
19 changes: 12 additions & 7 deletions modules/data_processing/zarr_utils.py
@@ -6,16 +6,15 @@
import geopandas as gpd
import numpy as np
import s3fs
from data_processing.s3fs_utils import S3ParallelFileSystem
import xarray as xr
from dask.distributed import Client, LocalCluster, progress
from data_processing.file_paths import file_paths
from fsspec.mapping import FSMap


logger = logging.getLogger(__name__)

def open_s3_store(url: str) -> FSMap:
"""Open an s3 store from a given url."""
return s3fs.S3Map(url, s3=s3fs.S3FileSystem(anon=True))

def load_zarr_datasets() -> xr.Dataset:
"""Load zarr datasets from S3 within the specified time range."""
Expand All @@ -30,14 +29,18 @@ def load_zarr_datasets() -> xr.Dataset:
f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr"
for var in forcing_vars
]
s3_stores = [open_s3_store(url) for url in s3_urls]
dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr")
# default cache is readahead which is detrimental to performance in this case
fs = S3ParallelFileSystem(anon=True, default_cache_type="none") # default_block_size
s3_stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
# the cache option here just holds accessed data in memory to prevent s3 being queried multiple times
# most of the data is read once and written to disk but some of the coordinate data is read multiple times
dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr", cache=True)
return dataset


def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) -> Tuple[str, str]:
end_time_in_dataset = dataset.time[-1].values
start_time_in_dataset = dataset.time[0].values
end_time_in_dataset = dataset.time.isel(time=-1).values
start_time_in_dataset = dataset.time.isel(time=0).values
if np.datetime64(start_time) < start_time_in_dataset:
logger.warning(
f"provided start {start_time} is before the start of the dataset {start_time_in_dataset}, selecting from {start_time_in_dataset}"
@@ -130,11 +133,13 @@ def get_forcing_data(

if merged_data is None:
logger.info("Loading zarr stores")
# create new event loop
lazy_store = load_zarr_datasets()
logger.debug("Got zarr stores")
clipped_store = clip_dataset_to_bounds(lazy_store, gdf.total_bounds, start_time, end_time)
logger.info("Clipped forcing data to bounds")
merged_data = compute_store(clipped_store, forcing_paths.cached_nc_file)
logger.info("Forcing data loaded and cached")
# close the event loop

return merged_data
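Condensed view of the new loading path, with the pieces above assembled into one runnable snippet; the forcing_vars entries are placeholders, since the real list is defined in zarr_utils.py but not shown in this diff:

import s3fs
import xarray as xr
from data_processing.s3fs_utils import S3ParallelFileSystem

forcing_vars = ["precip", "t2d"]  # hypothetical subset, for illustration only
s3_urls = [
    f"s3://noaa-nwm-retrospective-3-0-pds/CONUS/zarr/forcing/{var}.zarr"
    for var in forcing_vars
]
# readahead caching hurts here because zarr issues many small scattered reads,
# so block caching is off and large objects are fetched as parallel ranged GETs
fs = S3ParallelFileSystem(anon=True, default_cache_type="none")
stores = [s3fs.S3Map(url, s3=fs) for url in s3_urls]
dataset = xr.open_mfdataset(stores, parallel=True, engine="zarr", cache=True)
print(dataset)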