Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NF: add string and slice slicing of AxesManager #68

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions datarray/datarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,26 @@ class AxesManager(object):
DataArray(array(True, dtype=bool),
('date', ('stocks', ('aapl', 'ibm')), 'metric'))


Axes can also be accessed numerically:
Axes can be accessed numerically:

>>> A.axes[1] is A.axes.stocks
True

Calling the AxesManager with string arguments will return an
The axis name can be used as an index, as well as an attribute:

>>> A.axes['stocks'] is A.axes.stocks
True

Axes can also be sliced:

>>> A.axes[1:]
(Axis(name='stocks', index=1, labels=('aapl', 'ibm', 'goog', 'msft')), Axis(name='metric', index=2, labels=None))

*Calling* the AxesManager with string arguments will return an
:py:class:`AxisIndexer` object which can be used to restrict slices to
specified axes:

>>> Ai = A.axes('stocks', 'date')
>>> Ai = A.axes('stocks', 'date') # Note the parens
>>> np.all(Ai['aapl':'goog', 100] == A[100, 0:2])
DataArray(array(True, dtype=bool),
(('stocks', ('aapl', 'ibm')), 'metric'))
Expand Down Expand Up @@ -156,20 +165,37 @@ def __getitem__(self, n):

Parameters
----------
n : int
Index of axis to be returned.
n : int or string or slice
If int, index of axis to be returned. If string, name of axis to
be returned. If slice object, slice from AxesManager to return.

Returns
-------
The requested :py:class:`Axis`.

ax : Axis or AxesManager
The requested :py:class:`Axis` if `n` is an int or string. A new
AxesManager object if `n` is a slice.
"""
if not isinstance(n, int):
raise TypeError("AxesManager expects integer index")
axes = object.__getattribute__(self, '_axes')
if isinstance(n, int): # Integer slicing retuns Axis
try:
return axes[n]
except IndexError:
raise IndexError("Requested axis %i out of bounds" % n)
# Indexing by name returns Axis
namemap = object.__getattribute__(self, '_namemap')
try:
n = namemap[n]
except TypeError:
pass
else:
return axes[n]
# Indexing with slice object returns new AxesManager
try:
return object.__getattribute__(self, '_axes')[n]
except IndexError:
raise IndexError("Requested axis %i out of bounds" % n)
new_axes = axes[n]
except TypeError:
raise TypeError("Invalid axis index {0}".format(n))
arr = object.__getattribute__(self, '_arr')
return type(self)(arr, new_axes)

def __eq__(self, other):
"""Test for equality between two axes managers. Two axes managers are
Expand Down
43 changes: 28 additions & 15 deletions datarray/tests/test_data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import numpy as np

from datarray.datarray import Axis, DataArray, NamedAxisError, \
_pull_axis, _reordered_axes
from datarray.datarray import (Axis, AxesManager, DataArray, NamedAxisError,
_pull_axis, _reordered_axes)

import nose.tools as nt
import numpy.testing as npt
Expand Down Expand Up @@ -461,6 +461,14 @@ class TestAxesManager:
def setUp(self):
self.axes_spec = ('date', ('stocks', ('aapl', 'ibm', 'goog', 'msft')), 'metric')
self.A = DataArray(np.random.randn(200, 4, 10), axes=self.axes_spec)
self.axes = []
for i, spec in enumerate(self.axes_spec):
try:
name, labels = spec
except ValueError:
name, labels = spec, None
self.axes.append(
Axis(name=name, index=i, parent_arr=self.A, labels=labels))

def test_axes_name_collision(self):
"Test .axes object for attribute collisions with axis names"
Expand All @@ -475,21 +483,26 @@ def test_axes_name_collision(self):
nt.assert_equal(B.shape, (1,1,2,3))
nt.assert_true(np.all(A + A == 2*A))

def test_axes_numeric_access(self):
for i,spec in enumerate(self.axes_spec):
try:
name,labels = spec
except ValueError:
name,labels = spec,None
nt.assert_true(self.A.axes[i] == Axis(name=name, index=i,
parent_arr=self.A, labels=labels))
def test_axes_indexing(self):
n_axes = len(self.axes)
for i, exp_axis in enumerate(self.axes):
# Index with integer
nt.assert_equal(self.A.axes[i], exp_axis)
# Negative integer
nt.assert_equal(self.A.axes[i - n_axes], exp_axis)
# Name
nt.assert_equal(self.A.axes[exp_axis.name], exp_axis)
# Single element slice
one_axis = self.A.axes[i:i + 1]
nt.assert_equal(len(one_axis), 1)
nt.assert_equal(one_axis[0], exp_axis)
# Slice with more than one element
nt.assert_equal(self.A.axes[1:],
AxesManager(np.array(self.A), self.axes[1:]))

def test_axes_attribute_access(self):
for spec in self.axes_spec:
try:
name,labels = spec
except ValueError:
name,labels = spec,None
for axis in self.axes:
name = axis.name
nt.assert_true(getattr(self.A.axes, name) is self.A.axes(name))

def test_equality(self):
Expand Down