diff --git a/pyproject.toml b/pyproject.toml index d5514159..fee69722 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers = [ "Topic :: Software Development", ] dependencies = [ - "awkward >=2.5.1", + "awkward >=2.7.4", "dask >=2023.04.0", "cachetools", "typing_extensions >=4.8.0", diff --git a/src/dask_awkward/__init__.py b/src/dask_awkward/__init__.py index 34b5c4f5..ce43b7da 100644 --- a/src/dask_awkward/__init__.py +++ b/src/dask_awkward/__init__.py @@ -104,5 +104,6 @@ without_parameters, zeros_like, zip, + zip_no_broadcast, ) from dask_awkward.version import __version__ diff --git a/src/dask_awkward/lib/__init__.py b/src/dask_awkward/lib/__init__.py index 74d16d6c..31575c23 100644 --- a/src/dask_awkward/lib/__init__.py +++ b/src/dask_awkward/lib/__init__.py @@ -90,4 +90,5 @@ without_parameters, zeros_like, zip, + zip_no_broadcast, ) diff --git a/src/dask_awkward/lib/io/io.py b/src/dask_awkward/lib/io/io.py index 609a1376..f9239ca5 100644 --- a/src/dask_awkward/lib/io/io.py +++ b/src/dask_awkward/lib/io/io.py @@ -468,7 +468,7 @@ def to_dataframe( """ import dask from dask.dataframe import DataFrame as DaskDataFrame - from dask.dataframe.core import new_dd_object # type: ignore + from dask.dataframe.core import new_dd_object if parse_version(dask.__version__) >= parse_version("2025"): raise NotImplementedError( diff --git a/src/dask_awkward/lib/structure.py b/src/dask_awkward/lib/structure.py index 6b70c8b9..86d061c3 100644 --- a/src/dask_awkward/lib/structure.py +++ b/src/dask_awkward/lib/structure.py @@ -76,6 +76,7 @@ "without_parameters", "zeros_like", "zip", + "zip_no_broadcast", ) @@ -1343,6 +1344,88 @@ def zip( ) +class _ZipNoBroadcastDictInputFn: + def __init__(self, keys: Sequence[str], **kwargs: Any) -> None: + self.keys = keys + self.kwargs = kwargs + + def __call__(self, *parts: ak.Array) -> ak.Array: + return ak.zip_no_broadcast( + {k: p for k, p in builtins.zip(self.keys, list(parts))}, + **self.kwargs, + ) + + +class _ZipNoBroadcastListInputFn: + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + + def __call__(self, *parts: Any) -> ak.Array: + return ak.zip_no_broadcast(list(parts), **self.kwargs) + + +@borrow_docstring(ak.zip_no_broadcast) +def zip_no_broadcast( + arrays: Sequence[Array] | Mapping[str, Array], + parameters: Mapping[str, Any] | None = None, + with_name: str | None = None, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping[str, Any] | None = None, +) -> Array: + if not highlevel: + raise ValueError("Only highlevel=True is supported") + + if isinstance(arrays, Mapping): + keys, colls, metadict = [], [], {} + for k, coll in arrays.items(): + keys.append(k) + colls.append(coll) + metadict[k] = coll._meta + + meta = ak.zip_no_broadcast( + metadict, + parameters=parameters, + with_name=with_name, + highlevel=highlevel, + behavior=behavior, + attrs=attrs, + ) + + return map_partitions( + _ZipNoBroadcastDictInputFn( + keys, + parameters=parameters, + with_name=with_name, + highlevel=highlevel, + behavior=behavior, + attrs=attrs, + ), + *colls, + label="zip_no_broadcast", + meta=meta, + ) + + elif isinstance(arrays, Sequence): + fn = _ZipNoBroadcastListInputFn( + parameters=parameters, + with_name=with_name, + highlevel=highlevel, + behavior=behavior, + attrs=attrs, + ) + return map_partitions( + fn, + *arrays, + label="zip_no_broadcast", + ) + + else: + raise DaskAwkwardNotImplemented( + "only mappings or sequences are supported by dak.zip_no_broadcast (e.g. dict, list, or tuple)" + ) + + def _repartition_func(*stuff): import builtins diff --git a/tests/test_structure.py b/tests/test_structure.py index 9f86b747..91341b65 100644 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -94,6 +94,48 @@ def test_zip_bad_input(daa: dak.Array) -> None: dak.zip(gd) +def test_zip_no_broadcast_dict_input(caa: ak.Array, daa: dak.Array) -> None: + da1 = daa["points"]["x"] + da2 = daa["points"]["x"] + ca1 = caa["points"]["x"] + ca2 = caa["points"]["x"] + + da_z = dak.zip_no_broadcast({"a": da1, "b": da2}) + ca_z = ak.zip_no_broadcast({"a": ca1, "b": ca2}) + assert_eq(da_z, ca_z) + + +def test_zip_no_broadcast_list_input(caa: ak.Array, daa: dak.Array) -> None: + da1 = daa.points.x + ca1 = caa.points.x + dz1 = dak.zip_no_broadcast([da1, da1]) + cz1 = ak.zip_no_broadcast([ca1, ca1]) + assert_eq(dz1, cz1) + dz2 = dak.zip_no_broadcast([da1, da1, da1]) + cz2 = ak.zip_no_broadcast([ca1, ca1, ca1]) + assert_eq(dz2, cz2) + + +def test_zip_no_broadcast_tuple_input(caa: ak.Array, daa: dak.Array) -> None: + da1 = daa.points.x + ca1 = caa.points.x + dz1 = dak.zip_no_broadcast((da1, da1)) + cz1 = ak.zip_no_broadcast((ca1, ca1)) + assert_eq(dz1, cz1) + dz2 = dak.zip_no_broadcast((da1, da1, da1)) + cz2 = ak.zip_no_broadcast((ca1, ca1, ca1)) + assert_eq(dz2, cz2) + + +def test_zip_no_broadcast_bad_input(daa: dak.Array) -> None: + da1 = daa.points.x + gd = (x for x in (da1, da1)) + with pytest.raises( + DaskAwkwardNotImplemented, match="only mappings or sequences are supported" + ): + dak.zip_no_broadcast(gd) + + def test_cartesian(caa: ak.Array, daa: dak.Array) -> None: da1 = daa["points", "x"] da2 = daa["points", "y"]