Skip to content

Commit

Permalink
fix: properly handle mapping version of (arg)cartesian for at least d…
Browse files Browse the repository at this point in the history
…icts (#443)

* properly handle mapping version of (arg)cartesian for at least dicts

* naming

Co-authored-by: Angus Hollands <[email protected]>

* naming

Co-authored-by: Angus Hollands <[email protected]>

* naming

Co-authored-by: Angus Hollands <[email protected]>

* naming

Co-authored-by: Angus Hollands <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Angus Hollands <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 5, 2024
1 parent ea7a8be commit 7427bc1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from awkward.types.type import Type
from awkward.typetracer import create_unknown_scalar, is_unknown_scalar
from dask.base import is_dask_collection, tokenize
from dask.base import is_dask_collection, tokenize, unpack_collections
from dask.highlevelgraph import HighLevelGraph

from dask_awkward.layers import AwkwardMaterializedLayer
Expand Down Expand Up @@ -77,15 +77,12 @@


class _ArgCartesianFn:
def __init__(self, **kwargs):
def __init__(self, repacker, **kwargs):
self.repacker = repacker
self.kwargs = kwargs

def __call__(self, *arrays):
# FIXME: with proper typetracer/form rehydration support we
# should not need to manually touch this when it's a
# typetracer
arrays = [ak.typetracer.touch_data(a) for a in arrays]
return ak.argcartesian(arrays, **self.kwargs)
return ak.argcartesian(self.repacker(arrays)[0], **self.kwargs)


@borrow_docstring(ak.argcartesian)
Expand All @@ -104,7 +101,9 @@ def argcartesian(

# FIXME: resolve negative axis
if axis >= 1:
arrays_unpacked, repacker = unpack_collections(arrays, traverse=True)
fn = _ArgCartesianFn(
repacker,
axis=axis,
nested=nested,
parameters=parameters,
Expand All @@ -113,7 +112,9 @@ def argcartesian(
behavior=behavior,
attrs=attrs,
)
return map_partitions(fn, *arrays, label="argcartesian", output_divisions=1)
return map_partitions(
fn, *arrays_unpacked, label="argcartesian", output_divisions=1
)
raise DaskAwkwardNotImplemented("TODO")


Expand Down Expand Up @@ -238,11 +239,12 @@ def broadcast_arrays(


class _CartesianFn:
def __init__(self, **kwargs):
def __init__(self, repacker, **kwargs):
self.repacker = repacker
self.kwargs = kwargs

def __call__(self, *arrays):
return ak.cartesian(list(arrays), **self.kwargs)
return ak.cartesian(self.repacker(arrays)[0], **self.kwargs)


@borrow_docstring(ak.cartesian)
Expand All @@ -259,7 +261,9 @@ def cartesian(
if not highlevel:
raise ValueError("Only highlevel=True is supported")
if axis == 1:
arrays_unpacked, repacker = unpack_collections(arrays, traverse=True)
fn = _CartesianFn(
repacker,
axis=axis,
nested=nested,
parameters=parameters,
Expand All @@ -268,7 +272,9 @@ def cartesian(
behavior=behavior,
attrs=attrs,
)
return map_partitions(fn, *arrays, label="cartesian", output_divisions=1)
return map_partitions(
fn, *arrays_unpacked, label="cartesian", output_divisions=1
)
raise DaskAwkwardNotImplemented("TODO")


Expand Down
8 changes: 8 additions & 0 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def test_cartesian(caa: ak.Array, daa: dak.Array) -> None:
cz = ak.cartesian([ca1, ca2], axis=1)
assert_eq(dz, cz)

dz = dak.cartesian({"xx": da1, "yy": da2}, axis=1)
cz = ak.cartesian({"xx": ca1, "yy": ca2}, axis=1)
assert_eq(dz, cz)


def test_argcartesian(caa: ak.Array, daa: dak.Array) -> None:
da1 = daa["points", "x"]
Expand All @@ -115,6 +119,10 @@ def test_argcartesian(caa: ak.Array, daa: dak.Array) -> None:
cz = ak.argcartesian([ca1, ca2], axis=1)
assert_eq(dz, cz)

dz = dak.argcartesian({"xx": da1, "yy": da2}, axis=1)
cz = ak.argcartesian({"xx": ca1, "yy": ca2}, axis=1)
assert_eq(dz, cz)


def test_ones_like(caa: ak.Array, daa: dak.Array) -> None:
da1 = dak.ones_like(daa.points.x)
Expand Down

0 comments on commit 7427bc1

Please sign in to comment.