Skip to content

Commit

Permalink
Add integration testing to validate ragged array dims and orig dims m…
Browse files Browse the repository at this point in the history
…atch.
  • Loading branch information
kevinsantana11 committed Jan 16, 2025
1 parent 1040fe0 commit 8f1bfb1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
24 changes: 14 additions & 10 deletions clouddrift/adapters/ibtracs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,14 @@ def to_raggedarray(
Specify the dataset kind to retrieve. Specifying the kind can speed up execution time of specific querries
and operations. Default is "LAST_3_YEARS".
tmp_path: str, default adapter temp path (default)
Temporary path where intermediary files are stored. Default is ${osSpecificLocation}/clouddrift/ibtracs/.
Temporary path where intermediary files are stored. Default is ${osSpecificTempFileLocation}/clouddrift/ibtracs/.
Returns
-------
xarray.Dataset
IBTRACS dataset as a ragged array.
"""

os.makedirs(tmp_path, exist_ok=True)
src_url = _get_source_url(version, kind)

filename = src_url.split("/")[-1]
dst_path = os.path.join(tmp_path, filename)
download_with_progress([(src_url, dst_path)])

ds = xr.open_dataset(dst_path, engine="netcdf4")
ds = _get_original_dataset(version, kind, tmp_path)
ds = ds.rename_dims({"date_time": "obs"})

vars = list[Hashable]()
Expand Down Expand Up @@ -105,6 +97,18 @@ def to_raggedarray(
)
return ra.to_xarray()

def _get_original_dataset(
version: _Version, kind: _Kind, tmp_path: str = _DEFAULT_FILE_PATH
) -> xr.Dataset:
os.makedirs(tmp_path, exist_ok=True)
src_url = _get_source_url(version, kind)

filename = src_url.split("/")[-1]
dst_path = os.path.join(tmp_path, filename)
download_with_progress([(src_url, dst_path)])

return xr.open_dataset(dst_path, engine="netcdf4")


def _rowsize(idx: int, **kwargs):
ds: xr.Dataset | None = kwargs.get("dataset")
Expand Down
17 changes: 14 additions & 3 deletions tests/datasets_tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

import tests.utils as testutils
from clouddrift import datasets
from clouddrift import datasets, adapters
from clouddrift.ragged import apply_ragged, subset


Expand All @@ -15,8 +15,19 @@ def test_gdp6h(self):
self.assertTrue(ds)

def test_ibtracs(self):
with datasets.ibtracs() as ds:
self.assertTrue(ds)
options = dict(version="v04r01", kind="LAST_3_YEARS")
ragged_ds = adapters.ibtracs.to_raggedarray(**options)
ds = adapters.ibtracs._get_original_dataset(**options)
ragged_ds_first = subset(
ragged_ds,
{
"storm": 0
},
row_dim_name="storm",
rowsize_var_name="numobs"
)
ds_first = ds.sel(storm=0)
np.allclose(ds_first.usa_r34[:10].data, ragged_ds_first.usa_r34[:10].data, equal_nan=True)

def test_glad(self):
with datasets.glad() as ds:
Expand Down

0 comments on commit 8f1bfb1

Please sign in to comment.