Skip to content

Commit

Permalink
Add support for masked quantities
Browse files Browse the repository at this point in the history
Fixes #202.
  • Loading branch information
lpsinger committed Jan 28, 2025
1 parent 6826f49 commit 434b4c0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
0.8.0 (unreleased)
------------------

- Add support for masked quantities [#749]

0.7.0 (2024-11-13)
------------------

Expand Down
12 changes: 10 additions & 2 deletions asdf_astropy/converters/unit/quantity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import numpy as np
from asdf.extension import Converter
from asdf.tags.core.ndarray import NDArrayType
from astropy.units import Quantity
from astropy.utils.masked import Masked

MaskedQuantity = Masked(Quantity)


class QuantityConverter(Converter):
Expand All @@ -9,11 +14,12 @@ class QuantityConverter(Converter):
# The Distance class has no tag of its own, so we
# just serialize it as a quantity.
"astropy.coordinates.distances.Distance",
MaskedQuantity,
)

def to_yaml_tree(self, obj, tag, ctx):
node = {
"value": obj.value,
"value": np.ma.asarray(obj.value) if isinstance(obj, MaskedQuantity) else obj.value,
"unit": obj.unit,
}

Expand All @@ -39,4 +45,6 @@ def from_yaml_tree(self, node, tag, ctx):
value = value._make_array()
dtype = value.dtype

return Quantity(value, unit=node["unit"], copy=copy, dtype=dtype)
class_ = MaskedQuantity if isinstance(value, np.ma.MaskedArray) else Quantity

return class_(value, unit=node["unit"], copy=copy, dtype=dtype)
22 changes: 22 additions & 0 deletions asdf_astropy/converters/unit/tests/test_masked.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import asdf
import numpy as np
from astropy import units as u
from astropy.utils.masked import Masked, get_data_and_mask

MaskedQuantity = Masked(u.Quantity)

Check warning on line 6 in asdf_astropy/converters/unit/tests/test_masked.py

View check run for this annotation

Codecov / codecov/patch

asdf_astropy/converters/unit/tests/test_masked.py#L6

Added line #L6 was not covered by tests


def test_masked_quantity(tmp_path):
data = [1, 2, 3]
mask = [False, False, True]
file_path = tmp_path / "test.asdf"
with asdf.AsdfFile() as af:
af["quantity"] = Masked(data, mask) * u.yottamole
af.write_to(file_path)

Check warning on line 15 in asdf_astropy/converters/unit/tests/test_masked.py

View check run for this annotation

Codecov / codecov/patch

asdf_astropy/converters/unit/tests/test_masked.py#L9-L15

Added lines #L9 - L15 were not covered by tests

with asdf.open(file_path) as af:
assert isinstance(af["quantity"], MaskedQuantity)
assert af["quantity"].unit == u.yottamole
result_data, result_mask = get_data_and_mask(af["quantity"].value)
np.testing.assert_array_equal(result_data, data)
np.testing.assert_array_equal(result_mask, mask)

Check warning on line 22 in asdf_astropy/converters/unit/tests/test_masked.py

View check run for this annotation

Codecov / codecov/patch

asdf_astropy/converters/unit/tests/test_masked.py#L17-L22

Added lines #L17 - L22 were not covered by tests

0 comments on commit 434b4c0

Please sign in to comment.