Skip to content

Commit

Permalink
Merge pull request #554 from pfackeldey/allow_graphnode_renaming
Browse files Browse the repository at this point in the history
fix: allow renaming when rebuilding dask-awkward Arrays
  • Loading branch information
pfackeldey authored Dec 17, 2024
2 parents f35cc3b + 7fe9611 commit 8958f4c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,6 @@ venv.bak/

# mypy
.mypy_cache/

# pyright
pyrightconfig.json
4 changes: 2 additions & 2 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def __dask_postpersist__(self):
def _rebuild(self, dsk, *, rename=None):
name = self._name
if rename:
raise ValueError("rename= unsupported in dask-awkward")
name = rename.get(name, name)
return type(self)(dsk, name, self._meta, self.known_value)

def __reduce__(self):
Expand Down Expand Up @@ -969,7 +969,7 @@ def __setitem__(self, where: Any, what: Any) -> None:
def _rebuild(self, dsk, *, rename=None):
name = self.name
if rename:
raise ValueError("rename= unsupported in dask-awkward")
name = rename.get(name, name)
return Array(dsk, name, self._meta, divisions=self.divisions)

def reset_meta(self) -> None:
Expand Down
13 changes: 6 additions & 7 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,6 @@ def test_array_rebuild(ndjson_points_file: str) -> None:
y = daa.compute()
assert x.tolist() == y.tolist()

with pytest.raises(ValueError, match="rename= unsupported"):
daa._rebuild(daa.dask, rename={"x": "y"})


def test_type(ndjson_points_file: str) -> None:
daa = dak.from_json([ndjson_points_file] * 2)
Expand Down Expand Up @@ -661,14 +658,16 @@ def test_array_persist(daa: Array) -> None:
assert_eq(daa2, daa)


def test_scalar_persist_and_rebuild(daa: Array) -> None:
def test_scalar_persist(daa: Array) -> None:
coll = daa["points"][0]["x"][0]
coll2 = coll.persist()
assert_eq(coll, coll2)

m = dak.max(daa.points.x, axis=None)
with pytest.raises(ValueError, match="rename= unsupported"):
m._rebuild(m.dask, rename={m._name: "max2"})

def test_array_rename_when_rebuilding(daa: Array) -> None:
name = daa.name
new_name = "foobar"
assert daa._rebuild(dsk=daa.dask, rename={name: new_name}).name == new_name


def test_output_divisions(daa: Array) -> None:
Expand Down

0 comments on commit 8958f4c

Please sign in to comment.