From be4e7ba68847afd46529567892d0106fde49dc69 Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Mon, 16 Dec 2024 19:59:42 -0600 Subject: [PATCH] correctly deal with non-collection kwargs for already flat deps --- src/dask_awkward/lib/core.py | 17 +++++++++++++---- src/dask_awkward/lib/optimize.py | 13 ++++++++++++- tests/test_core.py | 9 +++++++++ 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index aa642bcf..0a046665 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -2171,7 +2171,7 @@ def map_partitions( message += f"- {type(arg)}" raise TypeError(message) - if len(kwargs) == 0: + if len(kwarg_flat_deps) == 0: non_traversed_deps, _ = unpack_collections(*args, traverse=False) if len(flat_deps) == len(non_traversed_deps) and all( id(traversed_dep) == id(non_traversed_dep) @@ -2184,6 +2184,7 @@ def map_partitions( token=token, meta=meta, output_divisions=output_divisions, + **kwargs, ) arg_flat_deps_expanded = [] @@ -2556,11 +2557,19 @@ def to_length_zero_arrays(objects: Sequence[Any]) -> tuple[Any, ...]: return tuple(map(length_zero_array_or_identity, objects)) -def map_meta(fn: Callable | ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None: - # NOTE: fn is assumed to be a *packed* function +def map_meta( + fn: Callable | ArgsKwargsPackedFunction, *deps: Any, **kwargs: Any +) -> ak.Array | None: + # NOTE: fn is assumed to be a *packed* function (so flat deps or ArgsKwargsPackedFunction) + # if ArgsKwargsPackedFunction we do not allow kwargs # as defined up in map_partitions. be careful! 
+ if isinstance(fn, ArgsKwargsPackedFunction) and len(kwargs) > 0: + raise ValueError("ArgsKwargsPackedFunctions may not have additional kwargs!") try: - meta = fn(*to_meta(deps)) + if isinstance(fn, ArgsKwargsPackedFunction): + meta = fn(*to_meta(deps)) + else: + meta = fn(*to_meta(deps), **kwargs) return meta except Exception as err: # if compute-unknown-meta is False then we don't care about diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 4b9dd6cf..041b4981 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -375,9 +375,20 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG arg.key if isinstance(arg, GraphNode) else arg for arg in layer.task.args ] + kwargs = { + k: v.key if isinstance(v, GraphNode) else v + for k, v in layer.task.kwargs.items() + } # how to do this with `.substitute(...)`? args2 = _recursive_replace(args, layer, parent, indices) - all_tasks.append(Task(chain_member, func, *args2)) + kwargs2 = { + k: v + for k, v in zip( + kwargs.keys(), + _recursive_replace(kwargs.values(), layer, parent, indices), + ) + } + all_tasks.append(Task(chain_member, func, *args2, **kwargs2)) else: func, *args = layer.dsk[chain_member] # mypy: ignore args2 = _recursive_replace(args, layer, parent, indices) diff --git a/tests/test_core.py b/tests/test_core.py index 3ebf0777..98b312ef 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -806,6 +806,11 @@ def test_map_partitions_args_and_kwargs_have_collection(): zc = my_power(xc, kwarg_y=yc) zl = dak.map_partitions(my_power, xl, kwarg_y=yl) + # kwargs that contain collections should be wrapped + assert isinstance( + zl.dask.layers[zl.name].task.func, dak.lib.core.ArgsKwargsPackedFunction + ) + assert_eq(zc, zl) zd = structured_function(inputs={"x": xc, "y": xc, "z": yc}) @@ -830,6 +835,9 @@ def test_map_partitions_args_and_kwargs_have_collection(): zg = my_power(xc, kwarg_y=2.0) zp = 
dak.map_partitions(my_power, xl, kwarg_y=2.0) + # this invocation of my_power shouldn't be wrapped, no collections + assert zp.dask.layers[zp.name].task.func is my_power + assert_eq(zg, zp) a = ak.Array( @@ -860,6 +868,7 @@ def test_map_partitions_args_and_kwargs_have_collection(): ccc=cc, ddd=dd, ) + assert_eq(res1, res2)