Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gp/feat/aman arbitrary path #1057

Merged
merged 10 commits into from
Jan 10, 2025
Merged
11 changes: 11 additions & 0 deletions docs/axisman.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ The output of the ``wrap`` cal should be::
Note the boresight entry is marked with a ``*``, indicating that it's
an AxisManager rather than a numpy array.

To access data in an AxisManager, use a path-like syntax where
iparask marked this conversation as resolved.
Show resolved Hide resolved
attribute names are separated by dots::

>>> n, ofs = 1000, 0
>>> dets = ["det0", "det1", "det2"]
>>> aman = core.AxisManager(core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs))
>>> child = core.AxisManager(core.LabelAxis("dets", dets + ["det3"]),core.OffsetAxis("samps", n, ofs - n // 2),)
>>> aman.wrap("child", child)
>>> print(aman["child.dets"])
LabelAxis(3:'det0','det1','det2')

To slice this object, use the restrict() method. First, let's
restrict in the 'dets' axis. Since it's an Axis of type LabelAxis,
the restriction selector must be a list of strings::
Expand Down
65 changes: 50 additions & 15 deletions sotodlib/core/axisman.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,60 @@ def move(self, name, new_name):
self._fields[new_name] = self._fields.pop(name)
self._assignments[new_name] = self._assignments.pop(name)
return self

def add_axis(self, a):
assert isinstance( a, AxisInterface)
self._axes[a.name] = a.copy()

def __contains__(self, name):
return name in self._fields or name in self._axes
attrs = name.split(".")
tmp_item = self
while attrs:
attr_name = attrs.pop(0)
if attr_name in tmp_item._fields:
tmp_item = tmp_item._fields[attr_name]
elif attr_name in tmp_item._axes:
tmp_item = tmp_item._axes[attr_name]
else:
return False
return True

def __getitem__(self, name):
if name in self._fields:
return self._fields[name]
if name in self._axes:
return self._axes[name]
raise KeyError(name)

# We want to support options like:
# aman.focal_plane.xi . aman['focal_plane.xi']
# We will safely assume that a getitem will always have '.' as the separator
attrs = name.split(".")
tmp_item = self
while attrs:
attr_name = attrs.pop(0)
if attr_name in tmp_item._fields:
tmp_item = tmp_item._fields[attr_name]
elif attr_name in tmp_item._axes:
tmp_item = tmp_item._axes[attr_name]
else:
raise KeyError(attr_name)
return tmp_item

def __setitem__(self, name, val):
if name in self._fields:
self._fields[name] = val

last_pos = name.rfind(".")
val_key = name
tmp_item = self
if last_pos > -1:
val_key = name[last_pos + 1:]
attrs = name[:last_pos]
tmp_item = self[attrs]

if val_key in tmp_item._fields:
tmp_item._fields[val_key] = val
else:
raise KeyError(name)

def __setattr__(self, name, value):
# Assignment to members update those members
# We will assume that a path exists until the last member.
# If any member prior to that does not exist a keyerror is raised.
if "_fields" in self.__dict__ and name in self._fields.keys():
self._fields[name] = value
else:
Expand All @@ -381,7 +412,11 @@ def __setattr__(self, name, value):
def __getattr__(self, name):
# Prevent members from override special class members.
if name.startswith("__"): raise AttributeError(name)
return self[name]
try:
val = self[name]
except KeyError as ex:
raise AttributeError(name) from ex
return val

def __dir__(self):
return sorted(tuple(self.__dict__.keys()) + tuple(self.keys()))
Expand Down Expand Up @@ -514,27 +549,27 @@ def concatenate(items, axis=0, other_fields='exact'):
output.wrap(k, new_data[k], axis_map)
else:
if other_fields == "exact":
## if every item named k is a scalar
## if every item named k is a scalar
err_msg = (f"The field '{k}' does not share axis '{axis}'; "
f"{k} is not identical across all items "
f"pass other_fields='drop' or 'first' or else "
f"remove this field from the targets.")

if np.any([np.isscalar(i[k]) for i in items]):
if not np.all([np.isscalar(i[k]) for i in items]):
raise ValueError(err_msg)
if not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]):
raise ValueError(err_msg)
output.wrap(k, items[0][k], axis_map)
continue

elif not np.all([i[k].shape==items[0][k].shape for i in items]):
raise ValueError(err_msg)
elif not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]):
raise ValueError(err_msg)

output.wrap(k, items[0][k].copy(), axis_map)

elif other_fields == 'fail':
raise ValueError(
f"The field '{k}' does not share axis '{axis}'; "
Expand Down
48 changes: 43 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_130_not_inplace(self):

# This should return a separate thing.
rman = aman.restrict('samps', (10, 30), in_place=False)
#self.assertNotEqual(aman.a1[0], 0.)
# self.assertNotEqual(aman.a1[0], 0.)
self.assertEqual(len(aman.a1), 100)
self.assertEqual(len(rman.a1), 20)
self.assertNotEqual(aman.a1[10], 0.)
Expand Down Expand Up @@ -190,23 +190,23 @@ def test_170_concat(self):

# ... other_fields="exact"
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

## add scalars
amanA.wrap("ans", 42)
amanB.wrap("ans", 42)
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="exact"
amanB.azimuth[:] = 2.
with self.assertRaises(ValueError):
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="exact" and arrays of different shapes
amanB.move("azimuth", None)
amanB.wrap("azimuth", np.array([43,5,2,3]))
with self.assertRaises(ValueError):
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="fail"
amanB.move("azimuth",None)
amanB.wrap_new('azimuth', shape=('samps',))[:] = 2.
Expand Down Expand Up @@ -298,6 +298,44 @@ def test_300_restrict(self):
self.assertNotEqual(aman.a1[0, 0, 0, 1], 0.)

# wrap of AxisManager, merge.
def test_get_set(self):
iparask marked this conversation as resolved.
Show resolved Hide resolved
dets = ["det0", "det1", "det2"]
n, ofs = 1000, 0
aman = core.AxisManager(
core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs)
)
child = core.AxisManager(
core.LabelAxis("dets", dets + ["det3"]),
core.OffsetAxis("samps", n, ofs - n // 2),
)

child2 = core.AxisManager(
core.LabelAxis("dets2", ["det4", "det5"]),
core.OffsetAxis("samps", n, ofs - n // 2),
)
aman.wrap("child", child)
aman["child"].wrap("child2", child2)
iparask marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(aman["child.child2.dets2"].count, 2)
self.assertEqual(aman["child.dets"].name, "dets")
np.testing.assert_array_equal(aman["child.child2.dets2"].vals, np.array(["det4", "det5"]))
with self.assertRaises(KeyError):
aman["child.someentry"]

with self.assertRaises(KeyError):
aman["child2"]

with self.assertRaises(AttributeError):
aman["child.dets.an_extra_layer"]

self.assertIn("child.dets", aman)
self.assertIn("child.dets2", aman) # I am not sure why this is true
iparask marked this conversation as resolved.
Show resolved Hide resolved
self.assertNotIn("child.child2.someentry", aman)
iparask marked this conversation as resolved.
Show resolved Hide resolved

aman["child"] = child2
iparask marked this conversation as resolved.
Show resolved Hide resolved
print(aman["child"])
self.assertEqual(aman["child.dets2"].count, 2)
self.assertEqual(aman["child.dets2"].name, "dets2")
np.testing.assert_array_equal(aman["child.dets2"].vals, np.array(["det4", "det5"]))

def test_400_child(self):
dets = ['det0', 'det1', 'det2']
Expand Down
Loading