Skip to content

Commit

Permalink
feat: add ability to specify unit type for data loader (#87)
Browse files Browse the repository at this point in the history
* feat: add ability to specify unit type for data loader

* trigger pre-commit
  • Loading branch information
steven-murray authored Nov 13, 2023
1 parent dfe1b63 commit 6e90fbd
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
21 changes: 19 additions & 2 deletions py21cmsense/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pickle
import yaml
from astropy import units as un
from astropy.io.misc.yaml import AstropyLoader
from functools import wraps

Expand All @@ -18,7 +19,12 @@ class LoadError(IOError):


def data_loader(tag=None):
"""A decorator that turns a function into a YAML tag for loading external datafiles."""
"""A decorator that turns a function into a YAML tag for loading external datafiles.
The form of the tag is::
!<tag> <path_to_data> [| <unit>]
"""

def inner(fnc):
_DATA_LOADERS[fnc.__name__] = fnc
Expand All @@ -39,7 +45,18 @@ def wrapper(data):
raise LoadError(str(e))

def yaml_fnc(loader, node):
return wrapper(node.value)
args = node.value.split("|")
if len(args) == 1:
return wrapper(node.value)
elif len(args) == 2:
# Return with a unit
return wrapper(args[0].strip()) * getattr(un, args[1].strip())
else:
raise ValueError(
f"Too many arguments in {new_tag} tag. "
"Should be of the form: "
f"!{new_tag} <path_to_data> | <unit>"
)

for loader in _YAML_LOADERS:
yaml.add_constructor(f"!{new_tag}", yaml_fnc, Loader=loader)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pickle
from astropy import units as un
from astropy.io.misc import yaml

from py21cmsense.yaml import LoadError
Expand Down Expand Up @@ -46,3 +47,15 @@ def test_npz_loader(tmpdirec):
for k, v in d.items():
assert k in obj
assert np.allclose(v, obj[k])


def test_txt_loader_with_unit(tmpdirec):
txt = tmpdirec / "test-txt.txt"

obj = np.linspace(0, 1, 10)

np.savetxt(txt, obj)

d = yaml.load(f"!txt {txt} | m")
assert d.unit == un.m
assert np.allclose(d, obj * un.m)

0 comments on commit 6e90fbd

Please sign in to comment.