diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 17af655c02e..9b40a323f39 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -71,6 +71,8 @@ Bug fixes By `Kai Mühlbauer `_. - Use zarr-fixture to prevent thread leakage errors (:pull:`9967`). By `Kai Mühlbauer `_. +- Fix weighted ``polyfit`` for arrays with more than two dimensions (:issue:`9972`, :pull:`9974`). + By `Mattia Almansi `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a943d9bfc57..74f90ce9eea 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9206,7 +9206,7 @@ def polyfit( present_dims.update(other_dims) if w is not None: - rhs = rhs * w[:, np.newaxis] + rhs = rhs * w.reshape(-1, *((1,) * len(other_dims))) with warnings.catch_warnings(): if full: # Copy np.polyfit behavior diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8a90a05a4e3..f3867bd67d2 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6685,11 +6685,15 @@ def test_polyfit_output(self) -> None: assert len(out.data_vars) == 0 def test_polyfit_weighted(self) -> None: - # Make sure weighted polyfit does not change the original object (issue #5644) ds = create_test_data(seed=1) + ds = ds.broadcast_like(ds) # test more than 2 dimensions (issue #9972) ds_copy = ds.copy(deep=True) - ds.polyfit("dim2", 2, w=np.arange(ds.sizes["dim2"])) + expected = ds.polyfit("dim2", 2) + actual = ds.polyfit("dim2", 2, w=np.ones(ds.sizes["dim2"])) + xr.testing.assert_identical(expected, actual) + + # Make sure weighted polyfit does not change the original object (issue #5644) xr.testing.assert_identical(ds, ds_copy) def test_polyfit_coord(self) -> None: diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 7b295128d63..21870050034 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1240,8 +1240,12 @@ def test_repr_inherited_dims(self) -> None: ) def test_doc_example(self) -> None: # regression test for https://github.com/pydata/xarray/issues/9499 - time = xr.DataArray(data=["2022-01", "2023-01"], dims="time") - stations = xr.DataArray(data=list("abcdef"), dims="station") + time = xr.DataArray( + data=np.array(["2022-01", "2023-01"], dtype="