Skip to content

Commit

Permalink
retain retain dtype for packed data in datetime/timedelta encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
kmuehlbauer committed Feb 25, 2025
1 parent 031aaa2 commit 03aa1b8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
16 changes: 10 additions & 6 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,13 +1320,15 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
# in the case of packed data we need to encode into
# float first, the correct dtype will be established
# via CFScaleOffsetCoder/CFMaskCoder
set_dtype_encoding = None
if "add_offset" in encoding or "scale_factor" in encoding:
set_dtype_encoding = dtype
dtype = data.dtype if data.dtype.kind == "f" else "float64"
(data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)

# if no dtype is provided, preserve data.dtype in encoding
if dtype is None:
safe_setitem(encoding, "dtype", data.dtype, name=name)
# retain dtype for packed data
if set_dtype_encoding is not None:
safe_setitem(encoding, "dtype", set_dtype_encoding, name=name)
safe_setitem(attrs, "units", units, name=name)
safe_setitem(attrs, "calendar", calendar, name=name)

Expand Down Expand Up @@ -1383,14 +1385,16 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
# in the case of packed data we need to encode into
# float first, the correct dtype will be established
# via CFScaleOffsetCoder/CFMaskCoder
set_dtype_encoding = None
if "add_offset" in encoding or "scale_factor" in encoding:
set_dtype_encoding = dtype
dtype = data.dtype if data.dtype.kind == "f" else "float64"

data, units = encode_cf_timedelta(data, encoding.pop("units", None), dtype)

# if no dtype is provided, preserve data.dtype in encoding
if dtype is None:
safe_setitem(encoding, "dtype", data.dtype, name=name)
# retain dtype for packed data
if set_dtype_encoding is not None:
safe_setitem(encoding, "dtype", set_dtype_encoding, name=name)

safe_setitem(attrs, "units", units, name=name)

Expand Down
7 changes: 5 additions & 2 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def test_decode_datetime_mask_and_scale(
) -> None:
attrs = {
"units": "nanoseconds since 1970-01-01",
"calendar": "proleptic_gregorian",
"_FillValue": np.int16(-1),
"add_offset": 100000.0,
}
Expand All @@ -540,7 +541,8 @@ def test_decode_datetime_mask_and_scale(
"foo", encoded, mask_and_scale=mask_and_scale, decode_times=decode_times
)
result = conventions.encode_cf_variable(decoded, name="foo")
assert_equal(encoded, result)
assert_identical(encoded, result)
assert encoded.dtype == result.dtype


FREQUENCIES_TO_ENCODING_UNITS = {
Expand Down Expand Up @@ -1938,4 +1940,5 @@ def test_decode_timedelta_mask_and_scale(
"foo", encoded, mask_and_scale=mask_and_scale, decode_timedelta=decode_timedelta
)
result = conventions.encode_cf_variable(decoded, name="foo")
assert_equal(encoded, result)
assert_identical(encoded, result)
assert encoded.dtype == result.dtype

0 comments on commit 03aa1b8

Please sign in to comment.