refactor according to review concerns and suggestions
kmuehlbauer committed Feb 25, 2025
1 parent befa79e commit 031aaa2
Showing 3 changed files with 54 additions and 21 deletions.
20 changes: 15 additions & 5 deletions xarray/coding/times.py
@@ -1316,7 +1316,14 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
             units = encoding.pop("units", None)
             calendar = encoding.pop("calendar", None)
             dtype = encoding.pop("dtype", None)
+
+            # in the case of packed data we need to encode into
+            # float first, the correct dtype will be established
+            # via CFScaleOffsetCoder/CFMaskCoder
+            if "add_offset" in encoding or "scale_factor" in encoding:
+                dtype = data.dtype if data.dtype.kind == "f" else "float64"
             (data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)
+
             # if no dtype is provided, preserve data.dtype in encoding
             if dtype is None:
                 safe_setitem(encoding, "dtype", data.dtype, name=name)
@@ -1371,17 +1378,20 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
         if np.issubdtype(variable.data.dtype, np.timedelta64):
             dims, data, attrs, encoding = unpack_for_encoding(variable)

+            dtype = encoding.pop("dtype", None)
+
             # in the case of packed data we need to encode into
             # float first, the correct dtype will be established
             # via CFScaleOffsetCoder/CFMaskCoder
-            dtype = None
             if "add_offset" in encoding or "scale_factor" in encoding:
-                encoding.pop("dtype")
                 dtype = data.dtype if data.dtype.kind == "f" else "float64"

-            data, units = encode_cf_timedelta(
-                data, encoding.pop("units", None), encoding.get("dtype", dtype)
-            )
+            data, units = encode_cf_timedelta(data, encoding.pop("units", None), dtype)
+
+            # if no dtype is provided, preserve data.dtype in encoding
+            if dtype is None:
+                safe_setitem(encoding, "dtype", data.dtype, name=name)
+
             safe_setitem(attrs, "units", units, name=name)

             return Variable(dims, data, attrs, encoding, fastpath=True)
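
For orientation, a minimal sketch (not part of the commit) of the case these two hunks handle: a time variable whose encoding requests packing via scale_factor/add_offset. The data values and encoding numbers below are illustrative assumptions; Variable, encode_cf_variable and the CF encoding keys are existing xarray API.

    # Hedged sketch: packed datetime encoding. With this change the datetime coder
    # encodes to float first and leaves the packed dtype to CFScaleOffsetCoder/CFMaskCoder.
    import numpy as np
    import xarray as xr
    from xarray import conventions

    times = xr.Variable(
        ["time"],
        np.array(["2000-01-01", "2000-01-02", "2000-01-03"], dtype="datetime64[ns]"),
        encoding={
            "units": "days since 2000-01-01",
            "dtype": np.int16,
            "scale_factor": 0.5,
            "add_offset": 0.0,
            "_FillValue": np.int16(-1),
        },
    )
    encoded = conventions.encode_cf_variable(times, name="time")
    print(encoded.dtype, encoded.attrs.get("units"))
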
35 changes: 20 additions & 15 deletions xarray/coding/variables.py
@@ -484,25 +484,30 @@ def decode(self, variable: Variable, name: T_Name = None):
         )

         if encoded_fill_values:
-            # special case DateTime to properly handle NaT
-            # we need to check if time-like will be decoded or not
-            # in further processing
             dtype: np.typing.DTypeLike
             decoded_fill_value: Any
-            is_time_like = _is_time_like(attrs.get("units"))
-            if (
-                (is_time_like == "datetime" and self.decode_times)
-                or (is_time_like == "timedelta" and self.decode_timedelta)
-            ) and data.dtype.kind in "iu":
-                dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min
+            # in case of packed data we have to decode into float
+            # in any case
+            if "scale_factor" in attrs or "add_offset" in attrs:
+                dtype, decoded_fill_value = (
+                    _choose_float_dtype(data.dtype, attrs),
+                    np.nan,
+                )
             else:
-                if "scale_factor" not in attrs and "add_offset" not in attrs:
-                    dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype)
-                else:
+                # in case of no-packing special case DateTime/Timedelta to properly
+                # handle NaT, we need to check if time-like will be decoded
+                # or not in further processing
+                is_time_like = _is_time_like(attrs.get("units"))
+                if (
+                    (is_time_like == "datetime" and self.decode_times)
+                    or (is_time_like == "timedelta" and self.decode_timedelta)
+                ) and data.dtype.kind in "iu":
                     dtype, decoded_fill_value = (
-                        _choose_float_dtype(data.dtype, attrs),
-                        np.nan,
-                    )
+                        np.int64,
+                        np.iinfo(np.int64).min,
+                    )  # np.dtype(f"{is_time_like}64[s]")
+                else:
+                    dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype)

             transform = partial(
                 _apply_mask,
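
A hedged usage sketch of the decode branch changed above, mirroring the new test rather than any internal API: when packing attributes are present, the mask coder now always decodes into a float dtype chosen by _choose_float_dtype with NaN as the fill value, and the int64/NaT special case applies only to unpacked time-like data. The variable below is an assumed example.

    # Hedged sketch: packed, integer-typed time variable carrying a _FillValue.
    import numpy as np
    from xarray import Variable, conventions

    attrs = {
        "units": "seconds since 2000-01-01",
        "_FillValue": np.int16(-1),
        "scale_factor": 2.0,
        "add_offset": 100.0,
    }
    packed = Variable(["time"], np.array([0, -1, 1], dtype="int16"), attrs=attrs)

    # Mask/scale decoding runs before time decoding, so the fill value becomes
    # NaN in a float array and only then is the result converted to datetimes.
    decoded = conventions.decode_cf_variable("foo", packed, mask_and_scale=True)
    print(decoded.dtype)
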
20 changes: 19 additions & 1 deletion xarray/tests/test_coding_times.py
@@ -525,6 +525,24 @@ def test_decoded_cf_datetime_array_2d(time_unit: PDDatetimeUnitOptions) -> None:
     assert_array_equal(np.asarray(result), expected)


+@pytest.mark.parametrize("decode_times", [True, False])
+@pytest.mark.parametrize("mask_and_scale", [True, False])
+def test_decode_datetime_mask_and_scale(
+    decode_times: bool, mask_and_scale: bool
+) -> None:
+    attrs = {
+        "units": "nanoseconds since 1970-01-01",
+        "_FillValue": np.int16(-1),
+        "add_offset": 100000.0,
+    }
+    encoded = Variable(["time"], np.array([0, -1, 1], "int16"), attrs=attrs)
+    decoded = conventions.decode_cf_variable(
+        "foo", encoded, mask_and_scale=mask_and_scale, decode_times=decode_times
+    )
+    result = conventions.encode_cf_variable(decoded, name="foo")
+    assert_equal(encoded, result)
+
+
 FREQUENCIES_TO_ENCODING_UNITS = {
     "ns": "nanoseconds",
     "us": "microseconds",
@@ -1914,7 +1932,7 @@ def test_lazy_decode_timedelta_error() -> None:
 def test_decode_timedelta_mask_and_scale(
     decode_timedelta: bool, mask_and_scale: bool
 ) -> None:
-    attrs = {"units": "days", "_FillValue": np.int16(-1), "add_offset": 100.0}
+    attrs = {"units": "nanoseconds", "_FillValue": np.int16(-1), "add_offset": 100000.0}
     encoded = Variable(["time"], np.array([0, -1, 1], "int16"), attrs=attrs)
     decoded = conventions.decode_cf_variable(
         "foo", encoded, mask_and_scale=mask_and_scale, decode_timedelta=decode_timedelta
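
To run just the round-trip tests touched above from a checkout, a selection along these lines should work (the -k expression is an assumption about the test names):

    # Select the mask-and-scale time coding tests added/updated in this commit.
    import pytest

    pytest.main(["-q", "xarray/tests/test_coding_times.py", "-k", "mask_and_scale"])
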
