-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: adapt to new Task spec in dask, now used in blockwise #556
Changes from all commits
6158208
d8cc3e4
7312b31
f3461bb
e100c45
fc46473
341500a
8780dae
6f96f4f
d7b3d9f
883b23e
c25e1f6
7c2174c
2e6da4e
06a475c
ef10929
954f6e6
78d6503
3a88ec6
01b0a45
536dfc4
90cf27b
bdb42e7
91e4890
6420dcc
cee57c8
57a2472
6a19642
5d01fef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,7 +45,11 @@ | |
from dask.utils import OperatorMethodMixin as DaskOperatorMethodMixin | ||
from dask.utils import funcname, is_arraylike, key_split | ||
|
||
from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer | ||
from dask_awkward.layers import ( | ||
AwkwardBlockwiseLayer, | ||
AwkwardMaterializedLayer, | ||
_dask_uses_tasks, | ||
) | ||
from dask_awkward.lib.optimize import all_optimizations | ||
from dask_awkward.utils import ( | ||
ConcretizationTypeError, | ||
|
@@ -57,6 +61,9 @@ | |
is_empty_slice, | ||
) | ||
|
||
if _dask_uses_tasks: | ||
from dask._task_spec import TaskRef | ||
|
||
if TYPE_CHECKING: | ||
from awkward.contents.content import Content | ||
from awkward.forms.form import Form | ||
|
@@ -1928,7 +1935,10 @@ def partitionwise_layer( | |
pairs.extend([arg.name, "i"]) | ||
numblocks[arg.name] = (1,) | ||
elif isinstance(arg, Delayed): | ||
pairs.extend([arg.key, None]) | ||
if _dask_uses_tasks: | ||
pairs.extend([TaskRef(arg.key), None]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, that's correct 👍 |
||
else: | ||
pairs.extend([arg.key, None]) | ||
elif is_dask_collection(arg): | ||
raise DaskAwkwardNotImplemented( | ||
"Use of Array with other Dask collections is currently unsupported." | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,10 +13,17 @@ | |
from dask.highlevelgraph import HighLevelGraph | ||
from dask.local import get_sync | ||
|
||
from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer | ||
from dask_awkward.layers import ( | ||
AwkwardBlockwiseLayer, | ||
AwkwardInputLayer, | ||
_dask_uses_tasks, | ||
) | ||
from dask_awkward.lib.utils import typetracer_nochecks | ||
from dask_awkward.utils import first | ||
|
||
if _dask_uses_tasks: | ||
from dask._task_spec import GraphNode, Task, TaskRef | ||
|
||
if TYPE_CHECKING: | ||
from awkward._nplikes.typetracer import TypeTracerReport | ||
from dask.typing import Key | ||
|
@@ -234,14 +241,23 @@ def _touch_all_data(*args, **kwargs): | |
|
||
def _mock_output(layer): | ||
"""Update a layer to run the _touch_all_data.""" | ||
assert len(layer.dsk) == 1 | ||
if _dask_uses_tasks: | ||
new_layer = copy.deepcopy(layer) | ||
task = new_layer.task.copy() | ||
Comment on lines
+245
to
+246
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My guess is that the task specific copy is not required after the deepcopy. I was already contemplating whether we should get rid of |
||
# replace the original function with _touch_all_data | ||
# and keep the rest of the task the same | ||
task.func = _touch_all_data | ||
new_layer.task = task | ||
return new_layer | ||
else: | ||
assert len(layer.dsk) == 1 | ||
|
||
new_layer = copy.deepcopy(layer) | ||
mp = new_layer.dsk.copy() | ||
for k in iter(mp.keys()): | ||
mp[k] = (_touch_all_data,) + mp[k][1:] | ||
new_layer.dsk = mp | ||
return new_layer | ||
new_layer = copy.deepcopy(layer) | ||
mp = new_layer.dsk.copy() | ||
for k in iter(mp.keys()): | ||
mp[k] = (_touch_all_data,) + mp[k][1:] | ||
new_layer.dsk = mp | ||
return new_layer | ||
|
||
|
||
@no_type_check | ||
|
@@ -340,7 +356,10 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG | |
deps[outkey] = deps[chain[0]] | ||
[deps.pop(ch) for ch in chain[:-1]] | ||
|
||
subgraph = layer0.dsk.copy() # mypy: ignore | ||
if _dask_uses_tasks: | ||
all_tasks = [layer0.task] | ||
else: | ||
subgraph = layer0.dsk.copy() | ||
indices = list(layer0.indices) | ||
parent = chain[0] | ||
|
||
|
@@ -349,14 +368,28 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG | |
layer = dsk.layers[chain_member] | ||
for k in layer.io_deps: # mypy: ignore | ||
outlayer.io_deps[k] = layer.io_deps[k] | ||
func, *args = layer.dsk[chain_member] # mypy: ignore | ||
args2 = _recursive_replace(args, layer, parent, indices) | ||
subgraph[chain_member] = (func,) + tuple(args2) | ||
|
||
if _dask_uses_tasks: | ||
func = layer.task.func | ||
args = [ | ||
arg.key if isinstance(arg, GraphNode) else arg | ||
for arg in layer.task.args | ||
] | ||
# how to do this with `.substitute(...)`? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this still an open question? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I was unsure how do implement this with It's not a show stopper right now though. |
||
args2 = _recursive_replace(args, layer, parent, indices) | ||
all_tasks.append(Task(chain_member, func, *args2)) | ||
else: | ||
func, *args = layer.dsk[chain_member] # mypy: ignore | ||
args2 = _recursive_replace(args, layer, parent, indices) | ||
subgraph[chain_member] = (func,) + tuple(args2) | ||
parent = chain_member | ||
outlayer.numblocks = { | ||
i[0]: (numblocks,) for i in indices if i[1] is not None | ||
} # mypy: ignore | ||
outlayer.dsk = subgraph # mypy: ignore | ||
if _dask_uses_tasks: | ||
outlayer.task = Task.fuse(*all_tasks) | ||
else: | ||
outlayer.dsk = subgraph # mypy: ignore | ||
if hasattr(outlayer, "_dims"): | ||
del outlayer._dims | ||
outlayer.indices = tuple( # mypy: ignore | ||
|
@@ -379,11 +412,18 @@ def _recursive_replace(args, layer, parent, indices): | |
args2.append(layer.indices[ind][0]) | ||
elif layer.indices[ind][0] == parent: | ||
# arg refers to output of previous layer | ||
args2.append(parent) | ||
if _dask_uses_tasks: | ||
args2.append(TaskRef(parent)) | ||
else: | ||
args2.append(parent) | ||
else: | ||
# arg refers to things defined in io_deps | ||
indices.append(layer.indices[ind]) | ||
args2.append(f"__dask_blockwise__{len(indices) - 1}") | ||
arg2 = f"__dask_blockwise__{len(indices) - 1}" | ||
if _dask_uses_tasks: | ||
args2.append(TaskRef(arg2)) | ||
else: | ||
args2.append(arg2) | ||
elif isinstance(arg, list): | ||
args2.append(_recursive_replace(arg, layer, parent, indices)) | ||
elif isinstance(arg, tuple): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In live discussions, we were tending towards dropping backward compatibility here, which means dropping py3.9 support (which dask and numpy already have). Users of py3.9 will not have hit the original problem, since the new dask was not released for them.
This would also save about half the lOC in this PR.