Skip to content

Commit

Permalink
Refactor logic for shifting months to use Xarray instead of Pandas
Browse files Browse the repository at this point in the history
- Months are also shifted in the `_preprocess_dataset()` method now. Before months were being shifted twice, once when dropping incomplete seasons or DJF, and a second time when labeling time coordinates.
  • Loading branch information
tomvothecoder committed Nov 12, 2024
1 parent 05496ed commit 0b6852f
Showing 1 changed file with 145 additions and 154 deletions.
299 changes: 145 additions & 154 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,9 +1101,12 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]]
def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
"""Preprocess the dataset based on averaging settings.
Preprocessing operations include:
- Drop incomplete DJF seasons (leading/trailing)
- Drop leap days
Operations include:
1. Drop leap days for daily climatologies.
2. Subset the dataset based on the reference period.
3. Shift years for custom seasons spanning the calendar year.
4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons.
5. Drop incomplete seasons if specified.
Parameters
----------
Expand All @@ -1114,6 +1117,18 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
-------
xr.Dataset
"""
if (
self._freq == "day"
and self._mode in ["climatology", "departures"]
and self.calendar in ["gregorian", "proleptic_gregorian", "standard"]
):
ds = self._drop_leap_days(ds)

if self._mode == "climatology" and self._reference_period is not None:
ds = ds.sel(
{self.dim: slice(self._reference_period[0], self._reference_period[1])}
)

if (
self._freq == "season"
and self._season_config.get("custom_seasons") is not None
Expand All @@ -1129,34 +1144,24 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
if len(months) != 12:
ds = self._subset_coords_for_custom_seasons(ds, months)

if (
self._freq == "day"
and self._mode in ["climatology", "departures"]
and self.calendar in ["gregorian", "proleptic_gregorian", "standard"]
):
ds = self._drop_leap_days(ds)
ds = self._shift_custom_season_years(ds)

if self._freq == "season" and self._season_config.get("dec_mode") == "DJF":
ds = self._shift_djf_decembers(ds)

# TODO: Deprecate incomplete_djf.
if (
self._season_config.get("drop_incomplete_djf") is True
and self._season_config.get("drop_incomplete_seasons") is False
):
ds = self._drop_incomplete_djf(ds)

if (
self._freq == "season"
and self._season_config["drop_incomplete_seasons"] is True
):
ds = self._drop_incomplete_seasons(ds)

# TODO: Deprecate incomplete_djf. Only run this is drop_incomplete_seasons
# is False and drop_incomplete_djf is True.
if (
self._freq == "season"
and self._season_config.get("dec_mode") == "DJF"
and self._season_config.get("drop_incomplete_djf") is True
and self._season_config.get("drop_incomplete_seasons") is False
):
ds = self._drop_incomplete_djf(ds)

if self._mode == "climatology" and self._reference_period is not None:
ds = ds.sel(
{self.dim: slice(self._reference_period[0], self._reference_period[1])}
)

return ds

def _subset_coords_for_custom_seasons(
Expand Down Expand Up @@ -1191,6 +1196,119 @@ def _subset_coords_for_custom_seasons(

return ds_new

def _shift_custom_season_years(self, ds: xr.Dataset) -> xr.Dataset:
"""Shifts the year for custom seasons spanning the calendar year.
A season spans the calendar year if it includes "Jan" and "Jan" is not
the first month. For example, for
``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]``:
- ["Nov", "Dec"] are from the previous year.
- ["Jan", "Feb", "Mar"] are from the current year.
Therefore, ["Nov", "Dec"] need to be shifted a year forward for correct
grouping.
Parameters
----------
ds : xr.Dataset
The Dataset with time coordinates.
Returns
-------
xr.Dataset
The Dataset with shifted time coordinates.
Examples
--------
Before and after shifting months for "NDJFM" seasons:
>>> # Before shifting months
>>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1),
>>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)]
>>> # After shifting months
>>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1),
>>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)]
"""
ds_new = ds.copy()
custom_seasons = self._season_config["custom_seasons"]

span_months: List[int] = []

# Identify the months that span across years in custom seasons.
# This is done by checking if "Jan" is not the first month in the
# custom season and getting all months before "Jan".
for months in custom_seasons.values(): # type: ignore
month_nums = [MONTH_STR_TO_INT[month] for month in months]
if 1 in month_nums:
jan_index = month_nums.index(1)

if jan_index != 0:
span_months.extend(month_nums[:jan_index])
break

if span_months:
time_coords = ds_new[self.dim].copy()
idxs = np.where(time_coords.dt.month.isin(span_months))[0]

if isinstance(time_coords.values[0], cftime.datetime):
for idx in idxs:
time_coords.values[idx] = time_coords.values[idx].replace(
year=time_coords.values[idx].year + 1
)
else:
for idx in idxs:
time_coords.values[idx] = pd.Timestamp(
time_coords.values[idx]
) + pd.DateOffset(years=1)

ds_new = ds_new.assign_coords({self.dim: time_coords})

return ds_new

def _shift_djf_decembers(self, ds: xr.Dataset) -> xr.Dataset:
"""Shifts Decembers to the next year for "DJF" seasons.
This ensures correct grouping for "DJF" seasons by shifting Decembers
to the next year. Without this, grouping defaults to "JFD", which
is the native Xarray behavior.
Parameters
----------
ds : xr.Dataset
The Dataset with time coordinates.
Returns
-------
xr.Dataset
The Dataset with shifted time coordinates.
Examples
--------
Comparison of "JFD" and "DJF" seasons:
>>> # "JFD" (native xarray behavior)
>>> [(2000, "DJF", 1), (2000, "DJF", 2), (2000, "DJF", 12),
>>> (2001, "DJF", 1), (2001, "DJF", 2)]
>>> # "DJF" (shifted Decembers)
>>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12),
>>> (2001, "DJF", 1), (2001, "DJF", 2)]
"""
ds_new = ds.copy()
time_coords = ds_new[self.dim].copy()
dec_indexes = time_coords.dt.month == 12

time_coords.values[dec_indexes] = [
time.replace(year=time.year + 1) for time in time_coords.values[dec_indexes]
]

ds_new = ds_new.assign_coords({self.dim: time_coords})

return ds_new

def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset:
"""Drops incomplete DJF seasons within a continuous time series.
Expand Down Expand Up @@ -1612,41 +1730,16 @@ def _get_df_dt_components(
elif self._mode == "group_average":
df["month"] = time_coords[f"{self.dim}.month"].values

df = self._process_season_df(df)
custom_seasons = self._season_config.get("custom_seasons")
if custom_seasons is not None:
df = self._map_months_to_custom_seasons(df)

if drop_obsolete_cols:
df = self._drop_obsolete_columns(df)
df = self._map_seasons_to_mid_months(df)

return df

def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Processes a DataFrame of datetime components for the season frequency.
Parameters
----------
df : xr.DataArray
A DataFrame of seasonal datetime components.
Returns
-------
pd.DataFrame
A DataFrame of seasonal datetime components.
"""
df_new = df.copy()
custom_seasons = self._season_config.get("custom_seasons")
dec_mode = self._season_config.get("dec_mode")

if custom_seasons is not None:
df_new = self._map_months_to_custom_seasons(df_new)
df_new = self._shift_spanning_months(df_new)
else:
if dec_mode == "DJF":
df_new = self._shift_decembers(df_new)

return df_new

def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame:
"""Maps the month column in the DataFrame to a custom season.
Expand Down Expand Up @@ -1681,108 +1774,6 @@ def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame:

return df_new

def _shift_spanning_months(self, df: pd.DataFrame) -> pd.DataFrame:
"""Shifts months in seasons spanning the previous year to the next year.
A season spans the previous year if it includes the month of "Jan" and
"Jan" is not the first month of the season. For example, let's say we
define ``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]`` to
represent the southern hemisphere growing seasons, "NDJFM".
- ["Nov", "Dec"] are from the previous year since they are listed
before "Jan".
- ["Jan", "Feb", "Mar"] are from the current year.
Therefore, we need to shift ["Nov", "Dec"] a year forward in order for
xarray to group seasons correctly. Refer to the examples section below
for a visual demonstration.
Parameters
----------
df : pd.Dataframe
The DataFrame of xarray datetime components produced using the
"season" frequency".
Returns
-------
pd.DataFrame
The DataFrame of xarray dataetime copmonents with months spanning
previous year shifted over to the next year.
Examples
--------
Before and after shifting months for "NDJFM" seasons:
>>> # Before shifting months
>>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1),
>>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)]
>>> # After shifting months
>>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1),
>>> (2001, "NDJFM", 1), (2001, "NDJFM", 2)]
"""
df_new = df.copy()
custom_seasons = self._season_config["custom_seasons"]

span_months: List[int] = []

# Loop over the custom seasons and get the list of months for the
# current season. Convert those months to their integer representations.
# If 1 ("Jan") is in the list of months and it is NOT the first element,
# then get all elements before it (aka the spanning months).
for months in custom_seasons.values(): # type: ignore
month_nums = [MONTH_STR_TO_INT[month] for month in months]
try:
jan_index = month_nums.index(1)
if jan_index != 0:
span_months = span_months + month_nums[:jan_index]
break
except ValueError:
continue

if len(span_months) > 0:
df_new.loc[df_new["month"].isin(span_months), "year"] = df_new["year"] + 1

return df_new

def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame:
"""Shifts Decembers over to the next year for "DJF" seasons in-place.
For "DJF" seasons, Decembers must be shifted over to the next year in
order for the xarray groupby operation to correctly label and group the
corresponding time coordinates. If the aren't shifted over, grouping is
incorrectly performed with the native xarray "DJF" season (which is
actually "JFD").
Parameters
----------
df_season : pd.DataFrame
The DataFrame of xarray datetime components produced using the
"season" frequency.
Returns
-------
pd.DataFrame
The DataFrame of xarray datetime components with Decembers shifted
over to the next year.
Examples
--------
Comparison of "JFD" and "DJF" seasons:
>>> # "JFD" (native xarray behavior)
>>> [(2000, "DJF", 1), (2000, "DJF", 2), (2000, "DJF", 12),
>>> (2001, "DJF", 1), (2001, "DJF", 2)]
>>> # "DJF" (shifted Decembers)
>>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12),
>>> (2001, "DJF", 1), (2001, "DJF", 2)]
"""
df_season.loc[df_season["month"] == 12, "year"] = df_season["year"] + 1

return df_season

def _map_seasons_to_mid_months(self, df: pd.DataFrame) -> pd.DataFrame:
"""Maps the season column values to the integer of its middle month.
Expand Down

0 comments on commit 0b6852f

Please sign in to comment.