diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index c1e205a2..d4f7e622 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1397,15 +1397,16 @@ def _getitem_slice_on_zero(self, where): ) def _getitem_tuple(self, where): - where = field_access_to_front(where) + where, n_field_accesses = field_access_to_front(where) + if isinstance(where[0], int): return self._getitem_outer_int(where) elif isinstance(where[0], (str, list)): - first, rest = where[0], where[1:] + first, rest = where[:n_field_accesses], where[n_field_accesses:] if rest: - return self[first][rest] - return self[first] + return self._getitem_trivial_map_partitions(first)[rest] + return self._getitem_trivial_map_partitions(first) elif isinstance(where[0], slice) and is_empty_slice(where[0]): return self._getitem_trivial_map_partitions(where) diff --git a/src/dask_awkward/utils.py b/src/dask_awkward/utils.py index b8609226..3deb3b38 100644 --- a/src/dask_awkward/utils.py +++ b/src/dask_awkward/utils.py @@ -1,8 +1,7 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, TypeVar -from collections.abc import Sequence from typing_extensions import ParamSpec @@ -160,7 +159,7 @@ def field_access_like(entry: Any) -> bool: return False -def field_access_to_front(seq: Sequence[Any]) -> tuple[Any, ...]: +def field_access_to_front(seq: Sequence[Any]) -> tuple[tuple[Any, ...], int]: new_seq = [] n_front = 0 for entry in seq: @@ -169,4 +168,4 @@ def field_access_to_front(seq: Sequence[Any]) -> tuple[Any, ...]: n_front += 1 else: new_seq.append(entry) - return tuple(new_seq) + return tuple(new_seq), n_front diff --git a/tests/test_utils.py b/tests/test_utils.py index 19b3a7b7..3d31a964 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -45,28 +45,36 @@ def test_hyphenize() -> None: ( (1, 3, 2, "z", "a"), ("z", "a", 1, 3, 2), + 2, ), ( ("a", 1, 2, ["1", "2"]), ("a", ["1", "2"], 1, 2), + 2, ), ( (0, ["a", "b", "c"]), (["a", "b", "c"], 0), + 1, ), ( ("hello", "abc"), ("hello", "abc"), + 2, ), ( (1, 2, slice(None, None, 2), 3), (1, 2, slice(None, None, 2), 3), + 0, ), ( (0, ["a", 0], ["a", "b"]), (["a", "b"], 0, ["a", 0]), + 1, ), ], ) def test_field_access_to_front(pairs): - assert field_access_to_front(pairs[0]) == pairs[1] + res = field_access_to_front(pairs[0]) + assert res[0] == pairs[1] + assert res[1] == pairs[2]