Skip to content

Commit

Permalink
Fix dataarray drop attrs (pydata#10030)
Browse files Browse the repository at this point in the history
* Fix DataArray().drop_attrs(deep=False)

* Add DataArray().drop_attrs() tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply small cosmetics

* Add support for attrs to DataArray()._replace

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove testing relict

* Fix (try) incompatible types mypy error

* Fix (2.try) incompatible types mypy error

* Update whats-new

* Fix replacing simultaneously passed variable

* Add DataArray()._replace() tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
j-haacker and pre-commit-ci[bot] authored Feb 9, 2025
1 parent df2ecf4 commit 54946eb
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ Bug fixes
- Use mean of min/max years as offset in calculation of datetime64 mean
(:issue:`10019`, :pull:`10035`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Fix DataArray().drop_attrs(deep=False) and add support for attrs to
DataArray()._replace(). (:issue:`10027`, :pull:`10030`). By `Jan
Haacker <https://github.com/j-haacker>`_.


Documentation
Expand Down
18 changes: 15 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import datetime
import warnings
from collections.abc import (
Expand Down Expand Up @@ -522,6 +523,7 @@ def _replace(
variable: Variable | None = None,
coords=None,
name: Hashable | None | Default = _default,
attrs=_default,
indexes=None,
) -> Self:
if variable is None:
Expand All @@ -532,6 +534,11 @@ def _replace(
indexes = self._indexes
if name is _default:
name = self.name
if attrs is _default:
attrs = copy.copy(self.attrs)
else:
variable = variable.copy()
variable.attrs = attrs
return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True)

def _replace_maybe_drop_dims(
Expand Down Expand Up @@ -7575,6 +7582,11 @@ def drop_attrs(self, *, deep: bool = True) -> Self:
-------
DataArray
"""
return (
self._to_temp_dataset().drop_attrs(deep=deep).pipe(self._from_temp_dataset)
)
if not deep:
return self._replace(attrs={})
else:
return (
self._to_temp_dataset()
.drop_attrs(deep=deep)
.pipe(self._from_temp_dataset)
)
20 changes: 19 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,6 +1908,21 @@ def test_rename_dimension_coord_warnings(self) -> None:
warnings.simplefilter("error")
da.rename(x="x")

def test_replace(self) -> None:
# Tests the `attrs` replacement and whether it interferes with a
# `variable` replacement
da = self.mda
attrs1 = {"a1": "val1", "a2": 161}
x = np.ones((10, 20))
v = Variable(["x", "y"], x)
assert da._replace(variable=v, attrs=attrs1).attrs == attrs1
attrs2 = {"b1": "val2", "b2": 1312}
va = Variable(["x", "y"], x, attrs2)
# assuming passed `attrs` should prevail
assert da._replace(variable=va, attrs=attrs1).attrs == attrs1
# assuming `va.attrs` should be adopted
assert da._replace(variable=va).attrs == attrs2

def test_init_value(self) -> None:
expected = DataArray(
np.full((3, 4), 3), dims=["x", "y"], coords=[range(3), range(4)]
Expand Down Expand Up @@ -2991,8 +3006,11 @@ def test_assign_attrs(self) -> None:

def test_drop_attrs(self) -> None:
# Mostly tested in test_dataset.py, but adding a very small test here
da = DataArray([], attrs=dict(a=1, b=2))
coord_ = DataArray([], attrs=dict(d=3, e=4))
da = DataArray([], attrs=dict(a=1, b=2)).assign_coords(dict(coord_=coord_))
assert da.drop_attrs().attrs == {}
assert da.drop_attrs().coord_.attrs == {}
assert da.drop_attrs(deep=False).coord_.attrs == dict(d=3, e=4)

@pytest.mark.parametrize(
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
Expand Down

0 comments on commit 54946eb

Please sign in to comment.