Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: adapt to new Task spec in dask, now used in blockwise #556

Merged
merged 29 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
6158208
fix: adapt to new Task spec in dask, now used in blockwise
lgray Dec 4, 2024
d8cc3e4
drop py3.8 from tests
lgray Dec 4, 2024
7312b31
ah, we need version check logic instead, great...
lgray Dec 4, 2024
f3461bb
whitespace
lgray Dec 4, 2024
e100c45
guard against missing _task_spec and Task classes in older dask
lgray Dec 4, 2024
fc46473
adjust min version requirements
lgray Dec 4, 2024
341500a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2024
8780dae
commas are good things
lgray Dec 4, 2024
6f96f4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2024
d7b3d9f
would you kindly...
lgray Dec 4, 2024
883b23e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2024
c25e1f6
appease mypy's insatiable lust for perfect correctness
lgray Dec 4, 2024
7c2174c
cleaner way of dealing with it
lgray Dec 4, 2024
2e6da4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2024
06a475c
forgot to update names
lgray Dec 4, 2024
ef10929
mypy
lgray Dec 4, 2024
954f6e6
missing types
lgray Dec 4, 2024
78d6503
mypy...
lgray Dec 4, 2024
3a88ec6
... mypy
lgray Dec 4, 2024
01b0a45
just ignore it all
lgray Dec 4, 2024
536dfc4
update rewrite_layer_chains with new dask Tasks
pfackeldey Dec 6, 2024
90cf27b
fix args preparation for rewrite_layer_chains and update _mock_output()
pfackeldey Dec 6, 2024
bdb42e7
use TaskRef to pass test - may not be correct
lgray Dec 7, 2024
91e4890
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2024
6420dcc
import _dask_uses_tasks
lgray Dec 7, 2024
cee57c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2024
57a2472
don't import in the function call
lgray Dec 7, 2024
6a19642
better check for dask._task_spec
lgray Dec 7, 2024
5d01fef
avoid imports in loops
lgray Dec 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ classifiers = [
]
dependencies = [
"awkward >=2.5.1",
"dask >=2023.04.0",
"dask >=2024.12.0;python_version>'3.9'",
"dask >=2023.04.0;python_version<'3.10'",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In live discussions, we were tending towards dropping backward compatibility here, which means dropping py3.9 support (which dask and numpy already have). Users of py3.9 will not have hit the original problem, since the new dask was not released for them.

This would also save about half the lOC in this PR.

"cachetools",
"typing_extensions >=4.8.0",
]
Expand Down
2 changes: 2 additions & 0 deletions src/dask_awkward/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ImplementsIOFunction,
ImplementsProjection,
IOFunctionWithMocking,
_dask_uses_tasks,
io_func_implements_projection,
)

Expand All @@ -18,4 +19,5 @@
"ImplementsIOFunction",
"IOFunctionWithMocking",
"io_func_implements_projection",
"_dask_uses_tasks",
)
29 changes: 21 additions & 8 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast

import dask

_dask_uses_tasks = hasattr(dask, "_task_spec")

from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token
from dask.highlevelgraph import MaterializedLayer
from dask.layers import DataFrameTreeReduction
from typing_extensions import TypeAlias

from dask_awkward.utils import LazyInputsDict

if _dask_uses_tasks:
from dask._task_spec import Task, TaskRef

if TYPE_CHECKING:
from awkward import Array as AwkwardArray
from awkward._nplikes.typetracer import TypeTracerReport
Expand Down Expand Up @@ -160,14 +167,20 @@ def __init__(
produces_tasks=self.produces_tasks,
)

super().__init__(
output=self.name,
output_indices="i",
dsk={name: (self.io_func, blockwise_token(0))},
indices=[(io_arg_map, "i")],
numblocks={},
annotations=None,
)
super_kwargs: dict[str, Any] = {
"output": self.name,
"output_indices": "i",
"indices": [(io_arg_map, "i")],
"numblocks": {},
"annotations": None,
}

if _dask_uses_tasks:
super_kwargs["task"] = Task(name, self.io_func, TaskRef(blockwise_token(0)))
else:
super_kwargs["dsk"] = {name: (self.io_func, blockwise_token(0))}

super().__init__(**super_kwargs)

def __repr__(self) -> str:
return f"AwkwardInputLayer<{self.output}>"
Expand Down
14 changes: 12 additions & 2 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@
from dask.utils import OperatorMethodMixin as DaskOperatorMethodMixin
from dask.utils import funcname, is_arraylike, key_split

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer
from dask_awkward.layers import (
AwkwardBlockwiseLayer,
AwkwardMaterializedLayer,
_dask_uses_tasks,
)
from dask_awkward.lib.optimize import all_optimizations
from dask_awkward.utils import (
ConcretizationTypeError,
Expand All @@ -57,6 +61,9 @@
is_empty_slice,
)

if _dask_uses_tasks:
from dask._task_spec import TaskRef

if TYPE_CHECKING:
from awkward.contents.content import Content
from awkward.forms.form import Form
Expand Down Expand Up @@ -1928,7 +1935,10 @@ def partitionwise_layer(
pairs.extend([arg.name, "i"])
numblocks[arg.name] = (1,)
elif isinstance(arg, Delayed):
pairs.extend([arg.key, None])
if _dask_uses_tasks:
pairs.extend([TaskRef(arg.key), None])
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that's correct 👍

else:
pairs.extend([arg.key, None])
elif is_dask_collection(arg):
raise DaskAwkwardNotImplemented(
"Use of Array with other Dask collections is currently unsupported."
Expand Down
70 changes: 55 additions & 15 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@
from dask.highlevelgraph import HighLevelGraph
from dask.local import get_sync

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
from dask_awkward.layers import (
AwkwardBlockwiseLayer,
AwkwardInputLayer,
_dask_uses_tasks,
)
from dask_awkward.lib.utils import typetracer_nochecks
from dask_awkward.utils import first

if _dask_uses_tasks:
from dask._task_spec import GraphNode, Task, TaskRef

if TYPE_CHECKING:
from awkward._nplikes.typetracer import TypeTracerReport
from dask.typing import Key
Expand Down Expand Up @@ -234,14 +241,23 @@ def _touch_all_data(*args, **kwargs):

def _mock_output(layer):
"""Update a layer to run the _touch_all_data."""
assert len(layer.dsk) == 1
if _dask_uses_tasks:
new_layer = copy.deepcopy(layer)
task = new_layer.task.copy()
Comment on lines +245 to +246
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My guess is that the task specific copy is not required after the deepcopy. I was already contemplating whether we should get rid of copy (because it is difficult to maintain / would require subclasses to overwrite it and we might want to make use of subclassing)

# replace the original function with _touch_all_data
# and keep the rest of the task the same
task.func = _touch_all_data
new_layer.task = task
return new_layer
else:
assert len(layer.dsk) == 1

new_layer = copy.deepcopy(layer)
mp = new_layer.dsk.copy()
for k in iter(mp.keys()):
mp[k] = (_touch_all_data,) + mp[k][1:]
new_layer.dsk = mp
return new_layer
new_layer = copy.deepcopy(layer)
mp = new_layer.dsk.copy()
for k in iter(mp.keys()):
mp[k] = (_touch_all_data,) + mp[k][1:]
new_layer.dsk = mp
return new_layer


@no_type_check
Expand Down Expand Up @@ -340,7 +356,10 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
deps[outkey] = deps[chain[0]]
[deps.pop(ch) for ch in chain[:-1]]

subgraph = layer0.dsk.copy() # mypy: ignore
if _dask_uses_tasks:
all_tasks = [layer0.task]
else:
subgraph = layer0.dsk.copy()
indices = list(layer0.indices)
parent = chain[0]

Expand All @@ -349,14 +368,28 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
layer = dsk.layers[chain_member]
for k in layer.io_deps: # mypy: ignore
outlayer.io_deps[k] = layer.io_deps[k]
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)

if _dask_uses_tasks:
func = layer.task.func
args = [
arg.key if isinstance(arg, GraphNode) else arg
for arg in layer.task.args
]
# how to do this with `.substitute(...)`?
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this still an open question?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was unsure how do implement this with .substitute(). I used our internal function instead, but it would be nice to use .substitute if that does the same thing.

It's not a show stopper right now though.

args2 = _recursive_replace(args, layer, parent, indices)
all_tasks.append(Task(chain_member, func, *args2))
else:
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)
parent = chain_member
outlayer.numblocks = {
i[0]: (numblocks,) for i in indices if i[1] is not None
} # mypy: ignore
outlayer.dsk = subgraph # mypy: ignore
if _dask_uses_tasks:
outlayer.task = Task.fuse(*all_tasks)
else:
outlayer.dsk = subgraph # mypy: ignore
if hasattr(outlayer, "_dims"):
del outlayer._dims
outlayer.indices = tuple( # mypy: ignore
Expand All @@ -379,11 +412,18 @@ def _recursive_replace(args, layer, parent, indices):
args2.append(layer.indices[ind][0])
elif layer.indices[ind][0] == parent:
# arg refers to output of previous layer
args2.append(parent)
if _dask_uses_tasks:
args2.append(TaskRef(parent))
else:
args2.append(parent)
else:
# arg refers to things defined in io_deps
indices.append(layer.indices[ind])
args2.append(f"__dask_blockwise__{len(indices) - 1}")
arg2 = f"__dask_blockwise__{len(indices) - 1}"
if _dask_uses_tasks:
args2.append(TaskRef(arg2))
else:
args2.append(arg2)
elif isinstance(arg, list):
args2.append(_recursive_replace(arg, layer, parent, indices))
elif isinstance(arg, tuple):
Expand Down
8 changes: 6 additions & 2 deletions tests/test_io_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

import dask_awkward as dak
from dask_awkward.layers import _dask_uses_tasks
from dask_awkward.lib.core import Array
from dask_awkward.lib.optimize import optimize as dak_optimize
from dask_awkward.lib.testutils import assert_eq
Expand Down Expand Up @@ -94,8 +95,11 @@ def input_layer_array_partition0(collection: Array) -> ak.Array:
optimized_hlg = dak_optimize(collection.dask, collection.keys) # type: ignore
layers = list(optimized_hlg.layers) # type: ignore
layer_name = [name for name in layers if name.startswith("from-json")][0]
sgc, arg = optimized_hlg[(layer_name, 0)]
array = sgc.dsk[layer_name][0](arg)
if _dask_uses_tasks:
array = optimized_hlg[(layer_name, 0)]()
else:
sgc, arg = optimized_hlg[(layer_name, 0)]
array = sgc.dsk[layer_name][0](arg)
return array


Expand Down
Loading