Skip to content

Commit

Permalink
Merge pull request #565 from pfackeldey/pfackeldey/manual_column_opti…
Browse files Browse the repository at this point in the history
…mization

feat: add possibility to manually perform the column projection
  • Loading branch information
pfackeldey authored Feb 20, 2025
2 parents ce1b71d + 22ed421 commit 9d6ccfd
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/dask_awkward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

necessary_columns = report_necessary_columns # Export for backwards compatibility.

import dask_awkward.manual as manual
from dask_awkward.lib.io.io import (
from_awkward,
from_dask_array,
Expand Down
35 changes: 32 additions & 3 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ def prepare_for_projection(self) -> tuple[AwkwardArray, TypeTracerReport, T]: ..

def project(self, report: TypeTracerReport, state: T) -> ImplementsIOFunction: ...

def project_manually(self, columns: frozenset[str]) -> ImplementsIOFunction:
    """Return an IO function projected onto the explicitly given ``columns``.

    This is typically an alias to ``project_columns``. Some IO functions
    might implement the method under a different name, because their file
    format uses different conventions. An example is ROOT, where the columns
    are called "keys"; in that case the method should be aliased to
    ``project_keys``.
    """
    ...


class ImplementsNecessaryColumns(ImplementsProjection[T], Protocol):
def necessary_columns(
Expand Down Expand Up @@ -156,6 +162,7 @@ def __init__(
produces_tasks: bool = False,
creation_info: dict | None = None,
annotations: Mapping[str, Any] | None = None,
is_projectable: bool | None = None,
) -> None:
self.name = name
self.inputs = inputs
Expand All @@ -164,6 +171,7 @@ def __init__(
self.produces_tasks = produces_tasks
self.annotations = annotations
self.creation_info = creation_info
self._is_projectable = is_projectable

io_arg_map = BlockwiseDepDict(
mapping=LazyInputsDict(self.inputs), # type: ignore
Expand Down Expand Up @@ -191,9 +199,16 @@ def __repr__(self) -> str:
@property
def is_projectable(self) -> bool:
    """Whether this layer may (still) be column-projected.

    An explicit override stored in ``_is_projectable`` wins whenever it is
    not ``None``. Otherwise the answer is derived: the IO function must
    implement projection and the layer must not have been unpickled.
    """
    # isinstance(self.io_func, ImplementsProjection)
    override = self._is_projectable
    if override is not None:
        return override
    return io_func_implements_projection(self.io_func) and not self.has_been_unpickled

@is_projectable.setter
def is_projectable(self, value: bool) -> None:
    """Pin the projectability flag, bypassing the automatic detection."""
    self._is_projectable = value

@property
def is_mockable(self) -> bool:
Expand Down Expand Up @@ -302,6 +317,20 @@ def necessary_columns(self, report: TypeTracerReport, state: T) -> frozenset[str
report=report, state=state
)

def project_manually(self, columns: frozenset[str]) -> AwkwardInputLayer:
    """Build a copy of this input layer that reads only ``columns``.

    The IO function is asked to project itself onto the given columns via
    its own ``project_manually``; every other constructor argument passed
    here (name, inputs, label, ...) is carried over from this layer.

    The layer must currently be projectable; this is enforced with an
    ``assert``.
    """
    assert self.is_projectable
    # The cast only informs the type checker that the IO function offers
    # the projection interface; it has no runtime effect.
    projecting_func = cast(ImplementsProjection, self.io_func)
    narrowed_io_func = projecting_func.project_manually(columns)
    return AwkwardInputLayer(
        name=self.name,
        inputs=self.inputs,
        io_func=narrowed_io_func,
        label=self.label,
        produces_tasks=self.produces_tasks,
        creation_info=self.creation_info,
        annotations=self.annotations,
    )


class AwkwardMaterializedLayer(MaterializedLayer):
def __init__(
Expand Down
5 changes: 5 additions & 0 deletions src/dask_awkward/lib/io/columnar.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def attrs(self) -> dict | None: ...

def project_columns(self: T, columns: frozenset[str]) -> T: ...

def project_manually(self: T, columns: frozenset[str]) -> ImplementsIOFunction:
    """Return an IO function projected onto the caller-supplied ``columns``."""
    ...

def __call__(self, *args, **kwargs): ...


Expand Down Expand Up @@ -176,3 +178,6 @@ def project(
return self

return self.project_columns(self.necessary_columns(report, state))

def project_manually(self: S, columns: frozenset[str]) -> ImplementsIOFunction:
    """Project onto explicitly chosen ``columns``.

    Pure delegation to ``project_columns``: the caller-supplied column set
    is forwarded unchanged.
    """
    projected = self.project_columns(columns)
    return projected
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def to_dataframe(
"""
import dask
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe.core import new_dd_object
from dask.dataframe.core import new_dd_object # type: ignore[attr-defined]

if parse_version(dask.__version__) >= parse_version("2025"):
raise NotImplementedError(
Expand Down
1 change: 1 addition & 0 deletions src/dask_awkward/manual/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from dask_awkward.manual.column_optimization import optimize_columns
58 changes: 58 additions & 0 deletions src/dask_awkward/manual/column_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

from typing import cast

from dask.highlevelgraph import HighLevelGraph

from dask_awkward.layers.layers import AwkwardInputLayer
from dask_awkward.lib.core import Array


def optimize_columns(array: Array, columns: dict[str, frozenset[str]]) -> Array:
    """
    Manually project the ``AwkwardInputLayer``(s) of ``array`` onto the given columns.

    This is useful for tracing the necessary buffers for a given computation
    once, and then reusing the typetracer reports to touch only the necessary
    columns for other datasets.

    Every layer named in ``columns`` is replaced by the layer returned from its
    ``project_manually`` method, i.e. columns that are not wanted are pruned.
    The replacement layers are flagged as not projectable, so this replaces
    the automatic column optimization; be careful when combining this function
    with ``.compute(optimize_graph=True)``.

    Parameters
    ----------
    array : Array
        The dask-awkward array to be optimized.
    columns : dict[str, frozenset[str]]
        Mapping from input-layer name to the columns to be touched by that
        layer.

    Returns
    -------
    Array
        A new dask-awkward array whose named input layers read only the
        specified columns. The layer mapping of the input graph is copied,
        not modified in place.

    Raises
    ------
    TypeError
        If ``array`` is not a ``dask_awkward.Array``, or if a named layer is
        not an ``AwkwardInputLayer``.
    KeyError
        If a key of ``columns`` is not the name of a layer in the graph.
    """
    if not isinstance(array, Array):
        raise TypeError(
            f"Expected `array` to be of type `dask_awkward.Array`, got {type(array)}"
        )

    dsk = array.dask
    # Shallow copies: layers are swapped in the copies, never in the input graph.
    layers = dict(dsk.layers)
    deps = dict(dsk.dependencies)

    for name, cols in columns.items():
        if name not in layers:
            # Same exception type as the plain dict lookup, but with context.
            raise KeyError(
                f"Layer {name} not found in the graph of `array`; "
                f"available layers: {list(layers)}"
            )
        io_layer = cast(AwkwardInputLayer, layers[name])
        if not isinstance(io_layer, AwkwardInputLayer):
            raise TypeError(
                f"Expected layer {name} to be of type `dask_awkward.layers.AwkwardInputLayer`, got {type(io_layer)}"
            )
        projected_layer = io_layer.project_manually(columns=cols)

        # Explicitly disable 'project-ability' now, since the projection was
        # just done manually.
        # Is there a better way to do this? Because this disables the
        # possibility to chain call `dak.manual.optimize_columns`.
        projected_layer.is_projectable = False

        layers[name] = projected_layer

    new_dsk = HighLevelGraph(layers, deps)
    return array._rebuild(dsk=new_dsk)
53 changes: 53 additions & 0 deletions tests/test_manual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import awkward as ak
import numpy as np
import pytest

import dask_awkward as dak


def test_optimize_columns():
    """Manual projection materializes exactly the requested columns.

    First only the columns reported as necessary for ``u4`` are requested:
    ``u4`` must come back as a real NumPy buffer while ``u8`` stays a
    placeholder. Then ``u8`` is added by hand and both must be real buffers.
    """
    for optional_dep in ("pyarrow", "requests", "aiohttp"):
        pytest.importorskip(optional_dep)

    def buffer_of(materialized, field):
        # Dig out the raw data buffer backing one record field.
        return materialized.layout.content(field).content.data

    source = dak.from_parquet(
        "https://github.com/scikit-hep/awkward/raw/main/tests/samples/nullable-record-primitives-simple.parquet"
    )

    needs = dak.inspect.report_necessary_columns(source.u4)
    u4_only = dak.manual.optimize_columns(source, needs)

    # Projection prunes what is read, not the visible record fields.
    assert u4_only.fields == ["u4", "u8"]

    u4_only_result = u4_only.compute()

    # u4 is materialized, u8 is not
    assert isinstance(buffer_of(u4_only_result, "u4"), np.ndarray)
    assert isinstance(
        buffer_of(u4_only_result, "u8"),
        ak._nplikes.placeholder.PlaceholderArray,
    )

    # now again, but we add 'u8' by hand to the columns
    layer_name, layer_columns = needs.popitem()
    needs = {layer_name: layer_columns | {"u8"}}

    u4_and_u8 = dak.manual.optimize_columns(source, needs)

    assert u4_and_u8.fields == ["u4", "u8"]

    u4_and_u8_result = u4_and_u8.compute()

    # now u4 and u8 are materialized
    for field in ("u4", "u8"):
        assert isinstance(buffer_of(u4_and_u8_result, field), np.ndarray)

0 comments on commit 9d6ccfd

Please sign in to comment.