Skip to content

Commit

Permalink
move field access to front, do getitem, then use rest
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Jan 26, 2024
1 parent d86b5fb commit f612e44
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
9 changes: 5 additions & 4 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,15 +1397,16 @@ def _getitem_slice_on_zero(self, where):
)

def _getitem_tuple(self, where):
where = field_access_to_front(where)
where, n_field_accesses = field_access_to_front(where)

if isinstance(where[0], int):
return self._getitem_outer_int(where)

elif isinstance(where[0], (str, list)):
first, rest = where[0], where[1:]
first, rest = where[:n_field_accesses], where[n_field_accesses:]
if rest:
return self[first][rest]
return self[first]
return self._getitem_trivial_map_partitions(first)[rest]
return self._getitem_trivial_map_partitions(first)

elif isinstance(where[0], slice) and is_empty_slice(where[0]):
return self._getitem_trivial_map_partitions(where)
Expand Down
7 changes: 3 additions & 4 deletions src/dask_awkward/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

from collections.abc import Callable, Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeVar
from collections.abc import Sequence

from typing_extensions import ParamSpec

Expand Down Expand Up @@ -160,7 +159,7 @@ def field_access_like(entry: Any) -> bool:
return False


def field_access_to_front(seq: Sequence[Any]) -> tuple[Any, ...]:
def field_access_to_front(seq: Sequence[Any]) -> tuple[tuple[Any, ...], int]:
new_seq = []
n_front = 0
for entry in seq:
Expand All @@ -169,4 +168,4 @@ def field_access_to_front(seq: Sequence[Any]) -> tuple[Any, ...]:
n_front += 1
else:
new_seq.append(entry)
return tuple(new_seq)
return tuple(new_seq), n_front
10 changes: 9 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,36 @@ def test_hyphenize() -> None:
(
(1, 3, 2, "z", "a"),
("z", "a", 1, 3, 2),
2,
),
(
("a", 1, 2, ["1", "2"]),
("a", ["1", "2"], 1, 2),
2,
),
(
(0, ["a", "b", "c"]),
(["a", "b", "c"], 0),
1,
),
(
("hello", "abc"),
("hello", "abc"),
2,
),
(
(1, 2, slice(None, None, 2), 3),
(1, 2, slice(None, None, 2), 3),
0,
),
(
(0, ["a", 0], ["a", "b"]),
(["a", "b"], 0, ["a", 0]),
1,
),
],
)
def test_field_access_to_front(pairs):
assert field_access_to_front(pairs[0]) == pairs[1]
res = field_access_to_front(pairs[0])
assert res[0] == pairs[1]
assert res[1] == pairs[2]

0 comments on commit f612e44

Please sign in to comment.