diff --git a/py21cmsense/yaml.py b/py21cmsense/yaml.py index 45e7495..331822a 100644 --- a/py21cmsense/yaml.py +++ b/py21cmsense/yaml.py @@ -3,6 +3,7 @@ import numpy as np import pickle import yaml +from astropy import units as un from astropy.io.misc.yaml import AstropyLoader from functools import wraps @@ -18,7 +19,12 @@ class LoadError(IOError): def data_loader(tag=None): - """A decorator that turns a function into a YAML tag for loading external datafiles.""" + """A decorator that turns a function into a YAML tag for loading external datafiles. + + The form of the tag is:: + + ! [| ] + """ def inner(fnc): _DATA_LOADERS[fnc.__name__] = fnc @@ -39,7 +45,18 @@ def wrapper(data): raise LoadError(str(e)) def yaml_fnc(loader, node): - return wrapper(node.value) + args = node.value.split("|") + if len(args) == 1: + return wrapper(node.value) + elif len(args) == 2: + # Return with a unit + return wrapper(args[0].strip()) * getattr(un, args[1].strip()) + else: + raise ValueError( + f"Too many arguments in {new_tag} tag. " + "Should be of the form: " + f"!{new_tag} | " + ) for loader in _YAML_LOADERS: yaml.add_constructor(f"!{new_tag}", yaml_fnc, Loader=loader) diff --git a/tests/test_yaml.py b/tests/test_yaml.py index 9905998..087c62b 100644 --- a/tests/test_yaml.py +++ b/tests/test_yaml.py @@ -2,6 +2,7 @@ import numpy as np import pickle +from astropy import units as un from astropy.io.misc import yaml from py21cmsense.yaml import LoadError @@ -46,3 +47,15 @@ def test_npz_loader(tmpdirec): for k, v in d.items(): assert k in obj assert np.allclose(v, obj[k]) + + +def test_txt_loader_with_unit(tmpdirec): + txt = tmpdirec / "test-txt.txt" + + obj = np.linspace(0, 1, 10) + + np.savetxt(txt, obj) + + d = yaml.load(f"!txt {txt} | m") + assert d.unit == un.m + assert np.allclose(d, obj * un.m)