Skip to content

Commit

Permalink
refactor: use helper functions to handle DataArray attributes wheneve…
Browse files Browse the repository at this point in the history
…r possible
  • Loading branch information
mtrocadomoreira committed Oct 21, 2024
1 parent 17dff95 commit 5d53830
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
7 changes: 4 additions & 3 deletions src/ozzy/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ def _coord_to_physical_distance(
def _save(instance, path):
dobj = instance._obj
for metadata in ["pic_data_type", "data_origin"]:
if metadata in dobj.attrs:
if dobj.attrs[metadata] is None:
dobj.attrs[metadata] = ""
dobj = set_attr_if_exists(dobj, metadata, str_doesnt="")
# if metadata in dobj.attrs:
# if dobj.attrs[metadata] is None:
# dobj.attrs[metadata] = ""

instance._obj.to_netcdf(path, engine="h5netcdf", compute=True, invalid_netcdf=True)
print(' -> Saved file "' + path + '" ')
Expand Down
30 changes: 13 additions & 17 deletions src/ozzy/part_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
# [email protected]
# *********************************************************


import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

from .new_dataobj import new_dataset
from .statistics import parts_into_grid
from .utils import axis_from_extent, bins_from_axis
from .utils import axis_from_extent, bins_from_axis, get_attr_if_exists


class PartMixin:
Expand Down Expand Up @@ -234,28 +233,25 @@ def mean_std(
)
result = result.rename({dim + "_w": dim + "_mean"})

if "long_name" in self._obj[dim].attrs:
newlname = "mean(" + self._obj[dim].attrs["long_name"] + ")"
else:
newlname = "mean"
newlname = get_attr_if_exists(
self._obj[dim], "long_name", lambda x: f"mean({x})", "mean"
)
result[dim + "_mean"].attrs["long_name"] = newlname

if "units" in self._obj[dim].attrs:
result[dim + "_mean"].attrs["units"] = self._obj[dim].attrs["units"]
newunits = get_attr_if_exists(self._obj[dim], "units")
if newunits is not None:
result[dim + "_mean"].attrs["units"] = newunits

else:
result[dim + "_std"] = np.sqrt(result[dim + "_sqw"])

# TODO: make function to use long_name and units if existing, otherwise replace with something

if "long_name" in self._obj[dim].attrs:
newlname = "std(" + self._obj[dim].attrs["long_name"] + ")"
else:
newlname = "std"
result[dim + "_std"].attrs["long_name"] = newlname
result[dim + "_std"].attrs["long_name"] = get_attr_if_exists(
self._obj[dim], "long_name", lambda x: f"std({x})", "std"
)

if "units" in self._obj[dim].attrs:
result[dim + "_std"].attrs["units"] = self._obj[dim].attrs["units"]
newunits = get_attr_if_exists(self._obj[dim], "units")
if newunits is not None:
result[dim + "_std"].attrs["units"] = newunits

result = result.drop_vars(dim + "_sqw")

Expand Down
11 changes: 7 additions & 4 deletions src/ozzy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,23 +756,24 @@ def bins_from_axis(axis: np.ndarray) -> np.ndarray:
def set_attr_if_exists(
da: xr.DataArray,
attr: str,
str_exists: str | Iterable[str] | Callable,
str_exists: str | Iterable[str] | Callable | None,
str_doesnt: str | None = None,
):
"""
Set or modify an attribute of a [DataArray][ozzy.core.DataArray] if it exists.
Set or modify an attribute of a [DataArray][ozzy.core.DataArray] if it exists, or modify if it doesn't exist or is `None`.
Parameters
----------
da : xarray.DataArray
The input DataArray.
attr : str
The name of the attribute to set or modify.
str_exists : str | Iterable[str] | Callable
str_exists : str | Iterable[str] | Callable | None, optional
The value or function to use if the attribute exists.
If `str`: replace the attribute with this string.
If [`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable): concatenate the first element, existing value, and second element.
If [`Callable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Callable): apply this function to the existing attribute value.
If `None`: do not change the attribute.
str_doesnt : str | None, optional
The value to set if the attribute doesn't exist. If `None`, no action is taken.
Expand Down Expand Up @@ -830,7 +831,7 @@ def set_attr_if_exists(
# Output: unknown
```
"""
if attr in da.attrs:
if (attr in da.attrs) and (da.attrs[attr] is not None):
if isinstance(str_exists, str):
da.attrs[attr] = str_exists
elif isinstance(str_exists, Iterable):
Expand All @@ -839,6 +840,8 @@ def set_attr_if_exists(
da.attrs[attr] = str_exists[0] + da.attrs[attr] + str_exists[1]
elif isinstance(str_exists, Callable):
da.attrs[attr] = str_exists(da.attrs[attr])
elif str_exists is None:
return da
else:
if str_doesnt is not None:
da.attrs[attr] = str_doesnt
Expand Down

0 comments on commit 5d53830

Please sign in to comment.