Skip to content

Commit

Permalink
reorder tuple to access fields first; adjust tuple with list as well
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Jan 25, 2024
1 parent 4639ff5 commit b81441e
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 18 deletions.
7 changes: 3 additions & 4 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from dask_awkward.utils import (
DaskAwkwardNotImplemented,
IncompatiblePartitions,
field_access_to_front,
first,
hyphenize,
is_empty_slice,
Expand Down Expand Up @@ -1396,18 +1397,16 @@ def _getitem_slice_on_zero(self, where):
)

def _getitem_tuple(self, where):
where = field_access_to_front(where)
if isinstance(where[0], int):
return self._getitem_outer_int(where)

elif isinstance(where[0], str):
elif isinstance(where[0], (str, list)):
first, rest = where[0], where[1:]
if rest:
return self[first][rest]
return self[first]

elif isinstance(where[0], list):
return self._getitem_outer_str_or_list(where)

elif isinstance(where[0], slice) and is_empty_slice(where[0]):
return self._getitem_trivial_map_partitions(where)

Expand Down
22 changes: 21 additions & 1 deletion src/dask_awkward/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Callable, Iterable, Mapping
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, Sequence, TypeVar

from typing_extensions import ParamSpec

Expand Down Expand Up @@ -149,3 +149,23 @@ def second(seq: Iterable[T]) -> T:
the_iter = iter(seq)
next(the_iter)
return next(the_iter)


def field_access_like(entry: Any) -> bool:
    """Report whether ``entry`` has the form of a field access.

    An entry qualifies when it is a single string, or when it is a
    list or tuple in which every element is a string. Because ``all``
    over zero elements is true, an empty list or tuple also qualifies.

    Parameters
    ----------
    entry : Any
        The object to classify.

    Returns
    -------
    bool
        ``True`` for a string or an all-string list/tuple, ``False``
        for anything else.

    """
    if isinstance(entry, str):
        return True
    # Only list and tuple containers are inspected element by element.
    if not isinstance(entry, (list, tuple)):
        return False
    for item in entry:
        if not isinstance(item, str):
            return False
    return True


def field_access_to_front(seq: Sequence[Any]) -> tuple[Any, ...]:
    """Return the entries of ``seq`` with field accesses moved to the front.

    Entries for which ``field_access_like`` is true are placed first,
    followed by all remaining entries. The relative order inside each
    of the two groups is the order in which they appear in ``seq``.

    Parameters
    ----------
    seq : Sequence[Any]
        The entries to reorder.

    Returns
    -------
    tuple[Any, ...]
        A new tuple holding the same entries, field accesses first.

    """
    fields: list[Any] = []
    others: list[Any] = []
    for item in seq:
        # Stable partition: each entry lands at the end of its own group.
        bucket = fields if field_access_like(item) else others
        bucket.append(item)
    return tuple(fields + others)
24 changes: 12 additions & 12 deletions tests/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,31 +161,31 @@ def test_firstarg_ellipsis_bad() -> None:
daa[..., 0]


# NOTE(review): this span is taken from a rendered diff, so the pre-change
# signature (next line) appears together with the post-change, parametrized
# signature, and some record rows are repeated. Confirm against the real file.
def test_multiarg_starting_with_string_gh454():
@pytest.mark.parametrize("i", [0, 1, 2, 3])
def test_multiarg_starting_with_string_gh454(i):
# Concrete reference array: nested lists of records with fields "a" and "b",
# including one empty list.
caa = ak.Array(
[
[
{"a": 1, "b": 5},
{"a": 2, "b": 6},
{"a": -2, "b": -6},
{"a": 1, "b": 5},
{"a": 2, "b": 6},
{"a": -2, "b": -6},
],
[
{"a": 1, "b": 5},
{"a": 2, "b": 6},
{"a": 1, "b": -5},
{"a": -2, "b": 6},
],
[],
[
{"a": 1, "b": 5},
{"a": 2, "b": 6},
{"a": -1, "b": 5},
{"a": -2, "b": 6},
],
]
)
# Collection built from caa with two partitions (dak is presumably
# dask_awkward; confirm against the file's imports).
daa = dak.from_awkward(caa, npartitions=2)
# A field name followed by an integer index must give the same result as
# the same index applied to the concrete array.
assert_eq(daa["a", 0], caa["a", 0])
assert_eq(daa["a", 1], caa["a", 1])
assert_eq(daa["a", 2], caa["a", 2])
assert_eq(daa["a", 3], caa["a", 3])
assert daa.defined_divisions
assert_eq(daa["a", i], caa["a", i])

# Reading defined_divisions on the field-plus-integer selection is expected
# to raise ValueError with the message matched below.
with pytest.raises(ValueError, match="only works when divisions are known"):
daa["a", 0].defined_divisions

# A list of field names followed by an integer index must also match.
assert_eq(daa[["a", "b"], i], caa[["a", "b"], i])
42 changes: 41 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from __future__ import annotations

from dask_awkward.utils import LazyInputsDict, hyphenize, is_empty_slice
import pytest

from dask_awkward.utils import (
LazyInputsDict,
hyphenize,
is_empty_slice,
field_access_to_front,
)


def test_is_empty_slice() -> None:
Expand Down Expand Up @@ -30,3 +37,36 @@ def test_hyphenize() -> None:
assert hyphenize("with_name") == "with-name"
assert hyphenize("with_a_name") == "with-a-name"
assert hyphenize("ok") == "ok"


@pytest.mark.parametrize(
    "pairs",
    [
        # Each case is (input sequence, expected reordered tuple).
        ((1, 3, 2, "z", "a"), ("z", "a", 1, 3, 2)),
        (("a", 1, 2, ["1", "2"]), ("a", ["1", "2"], 1, 2)),
        ((0, ["a", "b", "c"]), (["a", "b", "c"], 0)),
        # Already all field accesses: order is untouched.
        (("hello", "abc"), ("hello", "abc")),
        # No field accesses at all: order is untouched.
        ((1, 2, slice(None, None, 2), 3), (1, 2, slice(None, None, 2), 3)),
        # A list mixing a string and an int is not a field access.
        ((0, ["a", 0], ["a", "b"]), (["a", "b"], 0, ["a", 0])),
    ],
)
def test_field_access_to_front(pairs):
    """Check that field-access entries are moved ahead of all other entries."""
    given, expected = pairs
    reordered = field_access_to_front(given)
    assert reordered == expected

0 comments on commit b81441e

Please sign in to comment.