Skip to content

Commit

Permalink
feat: make from_awkward compatible with necessary_columns (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis authored Jan 24, 2024
1 parent edd5bc0 commit d3b6208
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 12 deletions.
36 changes: 25 additions & 11 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
new_array_object,
typetracer_array,
)
from dask_awkward.lib.io.columnar import ColumnProjectionMixin
from dask_awkward.utils import first, second

if TYPE_CHECKING:
Expand All @@ -45,12 +46,29 @@
logger = logging.getLogger(__name__)


class _FromAwkwardFn:
def __init__(self, arr: ak.Array) -> None:
class FromAwkwardFn(ColumnProjectionMixin):
def __init__(
self,
arr: ak.Array,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> None:
self.arr = arr
self.form = arr.layout.form
self.behavior = behavior
self.attrs = attrs

def __call__(self, start: int, stop: int, **kwargs: Any) -> ak.Array:
return cast(ak.Array, self.arr[start:stop])
@property
def use_optimization(self):
return True

def __call__(self, *args, **kwargs):
start, stop = args[0]
arr = cast(ak.Array, self.arr[start:stop])
return ak.Array(arr, behavior=self.behavior, attrs=self.attrs)

def project_columns(self, columns):
return type(self)(self.arr, self.behavior, self.attrs)


def from_awkward(
Expand Down Expand Up @@ -96,21 +114,17 @@ def from_awkward(
else:
chunksize = int(math.ceil(nrows / npartitions))
locs = tuple(list(range(0, nrows, chunksize)) + [nrows])
starts = locs[:-1]
stops = locs[1:]
starts_stops = list(zip(locs[:-1], locs[1:]))
meta = typetracer_array(source)
return cast(
Array,
from_map(
_FromAwkwardFn(source),
starts,
stops,
FromAwkwardFn(source, behavior=behavior),
starts_stops,
label=label or "from-awkward",
token=tokenize(source, npartitions),
divisions=locs,
meta=meta,
behavior=behavior,
attrs=attrs,
),
)

Expand Down
23 changes: 23 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dask_awkward.lib.core import typetracer_array
from dask_awkward.lib.io.io import _bytes_with_sample, from_map
from dask_awkward.lib.testutils import assert_eq
from dask_awkward.utils import first


def test_to_and_from_dask_array(daa: dak.Array) -> None:
Expand Down Expand Up @@ -415,3 +416,25 @@ def test_from_map_fail_with_callbacks():
_, rep = dask.compute(array, report)

assert "OSError" in rep.exception.tolist()


def test_from_awkward_necessary_columns(caa):
behavior = {}

@ak.mixin_class(behavior)
class Point:
@property
def xsq(self):
return self.x * self.x

@ak.mixin_class_method(np.abs)
def point_abs(self):
return np.sqrt(self.x**2 + self.y**2)

caa = ak.with_name(caa.points, name="Point", behavior=behavior)
daa = dak.from_awkward(caa, npartitions=2, behavior=behavior)
assert_eq(caa.xsq, daa.xsq)
assert set(first(dak.necessary_columns(daa.xsq).items())[1]) == {"x"}
assert set(first(dak.necessary_columns(daa).items())[1]) == {"x", "y"}
assert set(first(dak.necessary_columns(np.abs(daa)).items())[1]) == {"x", "y"}
assert_eq(np.abs(caa), np.abs(daa))
2 changes: 1 addition & 1 deletion tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_concatenate_axis_0_logical_different(daa):
result = dak.concatenate([daa, empty_dak_array], axis=0)

buffers_report = dak.report_necessary_buffers(result.points.x)
assert len(buffers_report) == 1
assert len(buffers_report) == 2

buffers = next(iter(buffers_report.values()))
assert buffers.data_and_shape == frozenset(
Expand Down

0 comments on commit d3b6208

Please sign in to comment.