From 4c3845d7c228478c4cbd4a0c816c07ec1bbf2165 Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Fri, 30 Aug 2024 16:07:55 -0700 Subject: [PATCH] Extract `_calculate_departures()` method --- tests/test_temporal.py | 60 ++++++++++----------- xcdat/temporal.py | 118 ++++++++++++++++++++++++++--------------- 2 files changed, 105 insertions(+), 73 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 9cd7c420..4385cfab 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -121,7 +121,7 @@ def test_averages_for_yearly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -139,7 +139,7 @@ def test_averages_for_yearly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_averages_for_monthly_time_series(self): # Set up dataset @@ -293,7 +293,7 @@ def test_averages_for_daily_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -310,7 +310,7 @@ def test_averages_for_daily_time_series(self): "weighted": "False", }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_averages_for_hourly_time_series(self): ds = xr.Dataset( @@ -378,7 +378,7 @@ def test_averages_for_hourly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -396,7 +396,7 @@ def test_averages_for_hourly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestGroupAverage: @@ -619,7 +619,7 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons( self, @@ -670,7 +670,7 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_averages_with_JFD(self): ds = self.ds.copy() @@ -729,7 +729,7 @@ def test_weighted_seasonal_averages_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_averages(self): ds = self.ds.copy() @@ -787,7 +787,7 @@ def test_weighted_custom_seasonal_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_raises_error_with_incorrect_custom_seasons_argument(self): # Test raises error with non-3 letter strings @@ -873,7 +873,7 @@ def test_weighted_monthly_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_monthly_averages_with_masked_data(self): ds = self.ds.copy() @@ -924,7 +924,7 @@ def test_weighted_monthly_averages_with_masked_data(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_averages(self): ds = self.ds.copy() @@ -967,7 +967,7 @@ def test_weighted_daily_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_hourly_averages(self): ds = self.ds.copy() @@ -1011,7 +1011,7 @@ def test_weighted_hourly_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestClimatology: @@ -1105,7 +1105,7 @@ def test_subsets_climatology_based_on_reference_period(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_climatology_with_DJF(self): ds = self.ds.copy() @@ -1159,7 +1159,7 @@ def test_weighted_seasonal_climatology_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) @requires_dask def test_chunked_weighted_seasonal_climatology_with_DJF(self): @@ -1214,7 +1214,7 @@ def test_chunked_weighted_seasonal_climatology_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_climatology_with_JFD(self): ds = self.ds.copy() @@ -1265,7 +1265,7 @@ def test_weighted_seasonal_climatology_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_climatology(self): ds = self.ds.copy() @@ -1328,7 +1328,7 @@ def test_weighted_custom_seasonal_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month") @@ -1391,7 +1391,7 @@ def test_weighted_monthly_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month", weighted=False) @@ -1453,7 +1453,7 @@ def test_unweighted_monthly_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_climatology(self): result = self.ds.temporal.climatology("ts", "day", weighted=True) @@ -1515,7 +1515,7 @@ def test_weighted_daily_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_climatology_drops_leap_days_with_matching_calendar(self): time = xr.DataArray( @@ -1606,7 +1606,7 @@ def test_weighted_daily_climatology_drops_leap_days_with_matching_calendar(self) }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_daily_climatology(self): result = self.ds.temporal.climatology("ts", "day", weighted=False) @@ -1668,7 +1668,7 @@ def test_unweighted_daily_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestDepartures: @@ -1895,7 +1895,7 @@ def test_monthly_departures_relative_to_climatology_reference_period_with_same_o }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -1945,7 +1945,7 @@ def test_weighted_seasonal_departures_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): ds = self.ds.copy() @@ -2021,7 +2021,7 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): dims=["time_original"], ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -2071,7 +2071,7 @@ def test_unweighted_seasonal_departures_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_seasonal_departures_with_JFD(self): ds = self.ds.copy() @@ -2121,7 +2121,7 @@ def test_unweighted_seasonal_departures_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_departures_drops_leap_days_with_matching_calendar(self): time = xr.DataArray( @@ -2214,7 +2214,7 @@ def test_weighted_daily_departures_drops_leap_days_with_matching_calendar(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class Test_GetWeights: diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 9b952c79..8ca7f729 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -747,16 +747,16 @@ def departures( # calling the group average or climatology APIs in step #3. ds = self._preprocess_dataset(ds) - # 3. Calculate the grouped average and climatology of the data variable. + # 3. Get the observational data variable. # ---------------------------------------------------------------------- - # The climatology and grouped average APIs are called on the copied - # dataset to create separate instances of the `TemporalAccessor` class. - # This is done to avoid overriding the attributes of the current - # instance of `TemporalAccessor` (set in step #1 above). + # NOTE: The xCDAT APIs are called on copies of the original dataset to + # create separate instances of the `TemporalAccessor` class. This is + # done to avoid overriding the attributes of the current instance of + # `TemporalAccessor`, which is set by step #1. + ds_obs = ds.copy() # Group averaging is only required if the dataset's frequency (input) # differs from the `freq` arg (output). - ds_obs = ds.copy() inferred_freq = _infer_freq(ds[self.dim]) if inferred_freq != freq: ds_obs = ds_obs.temporal.group_average( @@ -767,7 +767,10 @@ def departures( season_config, ) - ds_climo = ds.temporal.climatology( + # 4. Calculate the climatology of the data variable. + # ---------------------------------------------------------------------- + ds_climo = ds.copy() + ds_climo = ds_climo.temporal.climatology( data_var, freq, weighted, @@ -776,35 +779,11 @@ def departures( season_config, ) - # 4. Group the averaged data variable values by the time `freq`. - # ---------------------------------------------------------------------- - # This step allows us to perform xarray's grouped arithmetic to - # calculate departures. - dv_obs = ds_obs[data_var].copy() - self._labeled_time = self._label_time_coords(dv_obs[self.dim]) - dv_obs_grouped = self._group_data(dv_obs) - - # 5. Align time dimension names using the labeled time dimension name. - # ---------------------------------------------------------------------- - dv_climo = ds_climo[data_var] - - # 6. Calculate the departures for the data variable. + # 5. Calculate the departures for the data variable. # ---------------------------------------------------------------------- - # departures = observation - climatology - with xr.set_options(keep_attrs=True): - dv_departs = dv_obs_grouped - dv_climo - dv_departs = self._add_operation_attrs(dv_departs) - - # The original time dimension is dropped from the dataset to - # accomodate the new time dimension and its associated coordinates. - ds_obs = ds_obs.drop_dims(str(self.dim)) - ds_obs[data_var] = dv_departs + ds_departs = self._calculate_departures(ds_obs, ds_climo, data_var) - if weighted and keep_weights: - self._weights = ds_climo.time_wts - ds_obs = self._keep_weights(ds_obs) - - return ds_obs + return ds_departs def _averager( self, @@ -1306,7 +1285,7 @@ def _group_data(self, data_var: xr.DataArray) -> DataArrayGroupBy: dv_gb = dv.groupby(f"{self.dim}.{self._freq}") else: dv = dv.assign_coords({self.dim: self._labeled_time}) - dv_gb = dv.groupby(self._labeled_time.name) + dv_gb = dv.groupby(self.dim) return dv_gb @@ -1653,6 +1632,10 @@ def _convert_df_to_dt(self, df: pd.DataFrame) -> np.ndarray: def _keep_weights(self, ds: xr.Dataset) -> xr.Dataset: """Keep the weights in the dataset. + The labeled time coordinates for the weights are replaced with the + original time coordinates and the dimension name is appended with + "_original". + Parameters ---------- ds : xr.Dataset @@ -1663,16 +1646,11 @@ def _keep_weights(self, ds: xr.Dataset) -> xr.Dataset: xr.Dataset The dataset with the weights used for averaging. """ - # Append "_original" to the name of the weights` time coordinates to - # avoid conflict with the grouped time coordinates in the Dataset (can - # have a different shape). if self._mode in ["group_average", "climatology"]: - self._weights = self._weights.rename({self.dim: f"{self.dim}_original"}) - # Only keep the original time coordinates, not the ones labeled - # by group. - self._weights = self._weights.drop_vars(self._labeled_time.name) + weights = self._weights.assign_coords({self.dim: self._dataset[self.dim]}) + weights = weights.rename({self.dim: f"{self.dim}_original"}) - ds[self._weights.name] = self._weights + ds[weights.name] = weights return ds @@ -1717,6 +1695,60 @@ def _add_operation_attrs(self, data_var: xr.DataArray) -> xr.DataArray: return data_var + def _calculate_departures( + self, + ds_obs: xr.Dataset, + ds_climo: xr.Dataset, + data_var: str, + ) -> xr.Dataset: + """Calculate the departures for a data variable. + + How this methods works: + + 1. Label the observational data variable's time coordinates by their + appropriate time group. For example, the first two time + coordinates 2000-01-01 and 2000-02-01 are replaced with the + "01-01-01" and "01-02-01" monthly groups. + 2. Calculate departures by subtracting the climatology from the + labeled observational data using Xarray's grouped arithmetic with + automatic broadcasting (departures = obs - climo). + 3. Restore the original time coordinates to the departures variable + to preserve the "year" of the time coordinates. For example, + the first two time coordinates 01-01-01 and 01-02-01 are reverted + back to 2000-01-01 and 2000-02-01. + + Parameters + ---------- + ds_obs : xr.Dataset + The observational dataset. + dv_climo : xr.Dataset + The climatology dataset. + data_var : str + The key of the data variable for calculating departures. + + Returns + ------- + xr.Dataset + The dataset containing the departures for a data variable. + """ + ds_departs = ds_obs.copy() + + dv_obs = ds_obs[data_var].copy() + self._labeled_time = self._label_time_coords(dv_obs[self.dim]) + dv_obs_grouped = self._group_data(dv_obs) + + dv_climo = ds_climo[data_var].copy() + + with xr.set_options(keep_attrs=True): + dv_departs = dv_obs_grouped - dv_climo + + dv_departs = self._add_operation_attrs(dv_departs) + + dv_departs = dv_departs.assign_coords({self.dim: ds_obs[self.dim]}) + ds_departs[data_var] = dv_departs + + return ds_departs + def _infer_freq(time_coords: xr.DataArray) -> Frequency: """Infers the time frequency from the coordinates.