Skip to content

Commit

Permalink
implement map_over_datasets kwargs (pydata#10012)
Browse files Browse the repository at this point in the history
* add kwargs to map_over_datasets (similar to apply_ufunc), add test.

* try to fix typing

* improve typing and simplify kwargs-handling per review suggestions

* apply changes to DataTree.map_over_datasets

* add whats-new.rst entry

* Update xarray/core/datatree_mapping.py

Co-authored-by: Mathias Hauser <[email protected]>

* add suggestions from review.

---------

Co-authored-by: Mathias Hauser <[email protected]>
  • Loading branch information
kmuehlbauer and mathause authored Feb 10, 2025
1 parent c8f7dc6 commit 1189240
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 9 deletions.
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ v2025.02.0 (unreleased)

New Features
~~~~~~~~~~~~

- Allow passing keyword arguments to the mapped function through a new ``kwargs`` argument of :py:meth:`DataTree.map_over_datasets` and :py:func:`map_over_datasets` (:issue:`10009`, :pull:`10012`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,7 @@ def map_over_datasets(
self,
func: Callable,
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> DataTree | tuple[DataTree, ...]:
"""
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
Expand All @@ -1446,7 +1447,10 @@ def map_over_datasets(
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`.
Positional arguments passed on to `func`. Any DataTree arguments will be
converted to Dataset objects via `.dataset`.
kwargs : dict, optional
Keyword arguments passed directly to ``func`` on every call. Unlike
positional arguments, they are forwarded unchanged.
Returns
-------
Expand All @@ -1459,7 +1463,7 @@ def map_over_datasets(
"""
# TODO this signature means that func has no way to know which node it is being called upon - change?
# TODO fix this typing error
return map_over_datasets(func, self, *args)
return map_over_datasets(func, self, *args, kwargs=kwargs)

def pipe(
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
Expand Down
30 changes: 24 additions & 6 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,38 @@


@overload
def map_over_datasets(func: Callable[..., Dataset | None], *args: Any) -> DataTree: ...
def map_over_datasets(
func: Callable[
...,
Dataset | None,
],
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> DataTree: ...


@overload
def map_over_datasets(
func: Callable[..., tuple[Dataset | None, Dataset | None]], *args: Any
func: Callable[..., tuple[Dataset | None, Dataset | None]],
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> tuple[DataTree, DataTree]: ...


# the overload above is an explicit overload for the most common case of two return values
# (python typing does not have a way to match tuple lengths in general)
@overload
def map_over_datasets(
func: Callable[..., tuple[Dataset | None, ...]], *args: Any
func: Callable[..., tuple[Dataset | None, ...]],
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> tuple[DataTree, ...]: ...


def map_over_datasets(
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], *args: Any
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> DataTree | tuple[DataTree, ...]:
"""
Applies a function to every dataset in one or more DataTree objects with
Expand Down Expand Up @@ -62,12 +75,14 @@ def map_over_datasets(
func : callable
Function to apply to datasets with signature:
`func(*args: Dataset) -> Union[Dataset, tuple[Dataset, ...]]`.
`func(*args: Dataset, **kwargs) -> Union[Dataset, tuple[Dataset, ...]]`.
(i.e. func must accept at least one Dataset and return at least one Dataset.)
*args : tuple, optional
Positional arguments passed on to `func`. Any DataTree arguments will be
converted to Dataset objects via `.dataset`.
kwargs : dict, optional
Keyword arguments passed directly to ``func`` on every call. Unlike
positional arguments, they are forwarded unchanged.
Returns
-------
Expand All @@ -85,6 +100,9 @@ def map_over_datasets(

from xarray.core.datatree import DataTree

if kwargs is None:
kwargs = {}

# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
Expand All @@ -100,7 +118,7 @@ def map_over_datasets(
node_dataset_args.insert(i, arg)

func_with_error_context = _handle_errors_with_path_context(path)(func)
results = func_with_error_context(*node_dataset_args)
results = func_with_error_context(*node_dataset_args, **kwargs)
out_data_objects[path] = results

num_return_values = _check_all_return_values(out_data_objects)
Expand Down
23 changes: 23 additions & 0 deletions xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ def test_single_tree_arg_plus_arg(self, create_test_datatree):
result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt)
assert_equal(result_tree, expected)

def test_single_tree_arg_plus_kwarg(self, create_test_datatree):
    """Check that ``kwargs`` given to ``map_over_datasets`` reach the mapped function."""

    def multiply_by_kwarg(ds, **kwargs):
        # The scale factor arrives only through the forwarded keyword arguments.
        scaled = ds * kwargs.pop("multiplier")
        return scaled

    tree = create_test_datatree()
    # Reference tree: every dataset scaled by the same factor up front.
    expected = create_test_datatree(modify=lambda ds: (10.0 * ds))

    forwarded = {"multiplier": 10.0}
    result_tree = map_over_datasets(multiply_by_kwarg, tree, kwargs=forwarded)

    assert_equal(result_tree, expected)

def test_multiple_tree_args(self, create_test_datatree):
dt1 = create_test_datatree()
dt2 = create_test_datatree()
Expand Down Expand Up @@ -138,6 +151,16 @@ def multiply(ds, times):
result_tree = dt.map_over_datasets(multiply, 10.0)
assert_equal(result_tree, expected)

def test_tree_method_with_kwarg(self, create_test_datatree):
    """Check that ``DataTree.map_over_datasets`` forwards ``kwargs`` to the function."""

    def multiply(ds, **kwargs):
        # ``times`` is supplied through the method's ``kwargs`` argument.
        factor = kwargs.pop("times")
        return factor * ds

    tree = create_test_datatree()
    # Reference tree: every dataset scaled by the same factor up front.
    expected = create_test_datatree(modify=lambda ds: 10.0 * ds)

    forwarded = {"times": 10.0}
    result_tree = tree.map_over_datasets(multiply, kwargs=forwarded)

    assert_equal(result_tree, expected)

def test_discard_ancestry(self, create_test_datatree):
# Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48
dt = create_test_datatree()
Expand Down

0 comments on commit 1189240

Please sign in to comment.