Merge pull request #566 from dask-contrib/tree-reduction-fix
fix: DataFrameTreeReduction is no longer part of dask
martindurant authored Jan 27, 2025
2 parents 5e431bc + 8f40eac commit c3c2c4e
Showing 3 changed files with 284 additions and 5 deletions.
268 changes: 266 additions & 2 deletions src/dask_awkward/layers/layers.py
@@ -1,13 +1,16 @@
from __future__ import annotations

import copy
import math
import operator
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast

import dask
import toolz
from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token
from dask.highlevelgraph import MaterializedLayer
-from dask.layers import DataFrameTreeReduction
+from dask.layers import Layer
from typing_extensions import TypeAlias

from dask_awkward.utils import LazyInputsDict
@@ -366,7 +369,268 @@ def mock(self) -> MaterializedLayer:
return MaterializedLayer({(name, 0): task})


-class AwkwardTreeReductionLayer(DataFrameTreeReduction):
+class AwkwardTreeReductionLayer(Layer):
"""Awkward Tree-Reduction Layer
Parameters
----------
name : str
Name to use for the constructed layer.
name_input : str
Name of the input layer that is being reduced.
npartitions_input : int
Number of partitions in the input layer.
concat_func : callable
Function used by each tree node to reduce a list of inputs
into a single output value. This function must accept only
a list as its first positional argument.
tree_node_func : callable
Function used on the output of ``concat_func`` in each tree
node. This function must accept the output of ``concat_func``
as its first positional argument.
finalize_func : callable, optional
Function used in place of ``tree_node_func`` on the final tree
node(s) to produce the final output for each split. By default,
``tree_node_func`` will be used.
split_every : int, optional
This argument specifies the maximum number of input nodes
to be handled by any one task in the tree. Defaults to 32.
split_out : int, optional
This argument specifies the number of output nodes in the
reduction tree. If ``split_out`` is set to an integer >=1, the
input tasks must contain data that can be indexed by a ``getitem``
operation with a key in the range ``[0, split_out)``.
output_partitions : list, optional
List of required output partitions. This parameter is used
internally by Dask for high-level culling.
tree_node_name : str, optional
Name to use for intermediate tree-node tasks.
"""

name: str
name_input: str
npartitions_input: int
concat_func: Callable
tree_node_func: Callable
finalize_func: Callable | None
split_every: int
split_out: int
output_partitions: list[int]
tree_node_name: str
widths: list[int]
height: int

def __init__(
self,
name: str,
name_input: str,
npartitions_input: int,
concat_func: Callable,
tree_node_func: Callable,
finalize_func: Callable | None = None,
split_every: int = 32,
split_out: int | None = None,
output_partitions: list[int] | None = None,
tree_node_name: str | None = None,
annotations: dict[str, Any] | None = None,
):
super().__init__(annotations=annotations)
self.name = name
self.name_input = name_input
self.npartitions_input = npartitions_input
self.concat_func = concat_func
self.tree_node_func = tree_node_func
self.finalize_func = finalize_func
self.split_every = split_every
self.split_out = split_out # type: ignore
self.output_partitions = (
list(range(self.split_out or 1))
if output_partitions is None
else output_partitions
)
self.tree_node_name = tree_node_name or "tree_node-" + self.name

# Calculate tree widths and height
# (Used to get output keys without materializing)
parts = self.npartitions_input
self.widths = [parts]
while parts > 1:
parts = math.ceil(parts / self.split_every)
self.widths.append(int(parts))
self.height = len(self.widths)

def _make_key(self, *name_parts, split=0):
# Helper function to construct a key
# with a "split" element when
# bool(split_out) is True
return name_parts + (split,) if self.split_out else name_parts

def _define_task(self, input_keys, final_task=False):
# Define nested concatenation and func task
if final_task and self.finalize_func:
outer_func = self.finalize_func
else:
outer_func = self.tree_node_func
return (toolz.pipe, input_keys, self.concat_func, outer_func)

def _construct_graph(self):
"""Construct graph for a tree reduction."""

dsk = {}
if not self.output_partitions:
return dsk

# Deal with `bool(split_out) == True`.
# These cases require that the input tasks
# return a type that supports a getitem
# operation with indices in [0, split_out).
# Therefore, we must add "getitem" tasks to
# select the appropriate element for each split.
name_input_use = self.name_input
if self.split_out:
name_input_use += "-split"
for s in self.output_partitions:
for p in range(self.npartitions_input):
dsk[self._make_key(name_input_use, p, split=s)] = (
operator.getitem,
(self.name_input, p),
s,
)

if self.height >= 2:
# Loop over output splits
for s in self.output_partitions:
# Loop over reduction levels
for depth in range(1, self.height):
# Loop over reduction groups
for group in range(self.widths[depth]):
# Calculate inputs for the current group
p_max = self.widths[depth - 1]
lstart = self.split_every * group
lstop = min(lstart + self.split_every, p_max)
if depth == 1:
# Input nodes are from input layer
input_keys = [
self._make_key(name_input_use, p, split=s)
for p in range(lstart, lstop)
]
else:
# Input nodes are tree-reduction nodes
input_keys = [
self._make_key(
self.tree_node_name, p, depth - 1, split=s
)
for p in range(lstart, lstop)
]

# Define task
if depth == self.height - 1:
# Final node (apply finalize_func via _define_task)
assert (
group == 0
), f"group = {group}, not 0 for final tree reduction task"
dsk[(self.name, s)] = self._define_task(
input_keys, final_task=True
)
else:
# Intermediate Node
dsk[
self._make_key(
self.tree_node_name, group, depth, split=s
)
] = self._define_task(input_keys, final_task=False)
else:
# Deal with single-partition case
for s in self.output_partitions:
input_keys = [self._make_key(name_input_use, 0, split=s)]
dsk[(self.name, s)] = self._define_task(input_keys, final_task=True)

return dsk

def __repr__(self):
return "DataFrameTreeReduction<name='{}', input_name={}, split_out={}>".format(
self.name, self.name_input, self.split_out
)

def _output_keys(self):
return {(self.name, s) for s in self.output_partitions}

def get_output_keys(self):
if hasattr(self, "_cached_output_keys"):
return self._cached_output_keys
else:
output_keys = self._output_keys()
self._cached_output_keys = output_keys
return self._cached_output_keys

def is_materialized(self):
return hasattr(self, "_cached_dict")

@property
def _dict(self):
"""Materialize full dict representation"""
if hasattr(self, "_cached_dict"):
return self._cached_dict
else:
dsk = self._construct_graph()
self._cached_dict = dsk
return self._cached_dict

def __getitem__(self, key):
return self._dict[key]

def __iter__(self):
return iter(self._dict)

def __len__(self):
# Start with "base" tree-reduction size
tree_size = (sum(self.widths[1:]) or 1) * (self.split_out or 1)
if self.split_out:
# Add on "split-*" tasks used for `getitem` ops
return tree_size + self.npartitions_input * len(self.output_partitions)
return tree_size

def _keys_to_output_partitions(self, keys):
"""Simple utility to convert keys to output partition indices."""
splits = set()
for key in keys:
try:
_name, _split = key
except ValueError:
continue
if _name != self.name:
continue
splits.add(_split)
return splits

def _cull(self, output_partitions):
return AwkwardTreeReductionLayer(
self.name,
self.name_input,
self.npartitions_input,
self.concat_func,
self.tree_node_func,
finalize_func=self.finalize_func,
split_every=self.split_every,
split_out=self.split_out,
output_partitions=output_partitions,
tree_node_name=self.tree_node_name,
annotations=self.annotations,
)

def cull(self, keys, all_keys):
"""Cull a DataFrameTreeReduction HighLevelGraph layer"""
deps = {
(self.name, 0): {
(self.name_input, i) for i in range(self.npartitions_input)
}
}
output_partitions = self._keys_to_output_partitions(keys)
if output_partitions != set(self.output_partitions):
culled_layer = self._cull(output_partitions)
return culled_layer, deps
else:
return self, deps

def mock(self) -> AwkwardTreeReductionLayer:
return AwkwardTreeReductionLayer(
name=self.name,
...
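As a minimal sketch (not from this commit; the reducers are toy stand-ins), this shows how the (toolz.pipe, input_keys, concat_func, outer_func) tasks built by _define_task evaluate under dask's simple synchronous scheduler, and what the widths/height bookkeeping from __init__ produces for a concrete partition count:

import math

import toolz
from dask.core import get


def concat_func(values):
    # stand-in reducer: sum a list of partition values
    return sum(values)


def tree_node_func(value):
    # identity at interior tree nodes
    return value


dsk = {
    ("input", 0): 1,
    ("input", 1): 2,
    ("input", 2): 3,
    # same task shape as _define_task: dask resolves the key list,
    # then toolz.pipe threads it through concat_func and the node func
    ("reduced", 0): (
        toolz.pipe,
        [("input", 0), ("input", 1), ("input", 2)],
        concat_func,
        tree_node_func,
    ),
}
assert get(dsk, ("reduced", 0)) == 6

# the widths/height calculation from __init__: 1000 input partitions
# reduced 32-at-a-time gives levels of 1000, 32, and 1 node(s)
parts, split_every = 1000, 32
widths = [parts]
while parts > 1:
    parts = math.ceil(parts / split_every)
    widths.append(parts)
assert widths == [1000, 32, 1]  # height == len(widths) == 3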
12 changes: 9 additions & 3 deletions src/dask_awkward/lib/io/io.py
@@ -17,6 +17,7 @@
from dask.local import identity
from dask.utils import funcname, is_integer, parse_bytes
from fsspec.utils import infer_compression
from packaging.version import parse as parse_version

from dask_awkward.layers.layers import (
AwkwardBlockwiseLayer,
@@ -42,7 +43,7 @@
if TYPE_CHECKING:
from dask.array.core import Array as DaskArray
from dask.bag.core import Bag as DaskBag
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from fsspec.spec import AbstractFileSystem

@@ -466,9 +467,14 @@ def to_dataframe(
"""
import dask
-from dask.dataframe.core import DataFrame as DaskDataFrame
-from dask.dataframe.core import new_dd_object
+from dask.dataframe import DataFrame as DaskDataFrame
+from dask.dataframe.core import new_dd_object  # type: ignore

if parse_version(dask.__version__) >= parse_version("2025"):
raise NotImplementedError(
"to_dataframe is broken with dask 2025.1.0"
"either wait for a fix, or downgrade your dask."
)
if optimize_graph:
(array,) = dask.optimize(array)
intermediate = map_partitions(
...
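A quick standalone check (not part of the diff) of the version gate used above: packaging parses "2025" as release 2025.0, so every dask 2025.x release compares greater than or equal to parse_version("2025"), while 2024.x releases do not.

from packaging.version import parse as parse_version

assert parse_version("2025.1.0") >= parse_version("2025")
assert parse_version("2024.12.1") < parse_version("2025")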
9 changes: 9 additions & 0 deletions tests/test_io.py
@@ -11,6 +11,7 @@
from dask.delayed import delayed
from fsspec.core import get_fs_token_paths
from numpy.typing import DTypeLike
from packaging.version import parse as parse_version

import dask_awkward as dak
from dask_awkward.lib.core import typetracer_array
@@ -246,6 +247,10 @@ def test_to_bag(daa, caa):
assert comprec.tolist() == entry.tolist()


@pytest.mark.skipif(
parse_version(dask.__version__) >= parse_version("2025"),
reason="dask.DataFrame constructor changed",
)
@pytest.mark.parametrize("optimize_graph", [True, False])
def test_to_dataframe(daa: dak.Array, caa: ak.Array, optimize_graph: bool) -> None:
pytest.importorskip("pandas")
@@ -261,6 +266,10 @@ def test_to_dataframe(daa: dak.Array, caa: ak.Array, optimize_graph: bool) -> No
assert_eq(dd, df, check_index=False)


@pytest.mark.skipif(
parse_version(dask.__version__) >= parse_version("2025"),
reason="dask.DataFrame constructor changed",
)
@pytest.mark.parametrize("optimize_graph", [True, False])
def test_to_dataframe_str(
daa_str: dak.Array, caa_str: ak.Array, optimize_graph: bool
...
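At the usage level, the new guard means to_dataframe fails fast on dask >= 2025 instead of producing a broken collection. A hedged sketch of that expected behavior follows; dak.from_awkward and dak.to_dataframe are assumptions about the public dask-awkward API, not taken from this diff, and the ImportError branch accounts for new_dd_object possibly being absent from newer dask before the guard is reached.

import awkward as ak
import dask
import pytest
from packaging.version import parse as parse_version

import dask_awkward as dak

if parse_version(dask.__version__) >= parse_version("2025"):
    # a tiny one-partition collection; construction details are assumed
    arr = dak.from_awkward(ak.Array([{"x": 1.0}, {"x": 2.0}]), npartitions=1)
    with pytest.raises((NotImplementedError, ImportError)):
        dak.to_dataframe(arr)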
