diff --git a/nc_time_axis/__init__.py b/nc_time_axis/__init__.py index 4ef73a6..0899e86 100644 --- a/nc_time_axis/__init__.py +++ b/nc_time_axis/__init__.py @@ -226,17 +226,24 @@ def default_units(cls, sample_point, axis): Computes some units for the given data point. """ - try: - # Try getting the first item. Otherwise we just use this item. - sample_point = sample_point[0] - except (TypeError, IndexError): - pass - - if not hasattr(sample_point, 'calendar'): - msg = 'Expecting netcdftimes with an extra "calendar" attribute.' - raise ValueError(msg) - - return sample_point.calendar, cls.standard_unit + if hasattr(sample_point, '__iter__'): + # Deal with nD `sample_point` arrays. + if isinstance(sample_point, np.ndarray): + sample_point = sample_point.reshape(-1) + calendars = np.array([point.calendar for point in sample_point]) + if np.all(calendars[0] == calendars): + calendar = calendars[0] + else: + raise ValueError('Calendar units are not all equal.') + else: + # Deal with a single `sample_point` value. + if not hasattr(sample_point, 'calendar'): + msg = ('Expecting netcdftimes with an extra ' + '"calendar" attribute.') + raise ValueError(msg) + else: + calendar = sample_point.calendar + return calendar, cls.standard_unit @classmethod def convert(cls, value, unit, axis): @@ -245,11 +252,13 @@ def convert(cls, value, unit, axis): with :func:`netcdftime.utime().date2num`. """ + shape = None if isinstance(value, np.ndarray): # Don't do anything with numeric types. if value.dtype != np.object: return value - + shape = value.shape + value = value.reshape(-1) first_value = value[0] else: # Don't do anything with numeric types. @@ -270,7 +279,11 @@ def convert(cls, value, unit, axis): if isinstance(value, CalendarDateTime): value = [value] - return ut.date2num([v.datetime for v in value]) + result = ut.date2num([v.datetime for v in value]) + if shape is not None: + result = result.reshape(shape) + + return result # Automatically register NetCDFTimeConverter with matplotlib.unit's converter diff --git a/nc_time_axis/tests/unit/test_NetCDFTimeConverter.py b/nc_time_axis/tests/unit/test_NetCDFTimeConverter.py index e176bf0..eabc0dd 100644 --- a/nc_time_axis/tests/unit/test_NetCDFTimeConverter.py +++ b/nc_time_axis/tests/unit/test_NetCDFTimeConverter.py @@ -24,13 +24,41 @@ def test_axis_default_limits(self): class Test_default_units(unittest.TestCase): - def test_360_day_calendar(self): + def test_360_day_calendar_point(self): + calendar = '360_day' + unit = 'days since 2000-01-01' + val = CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar) + result = NetCDFTimeConverter().default_units(val, None) + self.assertEqual(result, (calendar, unit)) + + def test_360_day_calendar_list(self): calendar = '360_day' unit = 'days since 2000-01-01' val = [CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar)] result = NetCDFTimeConverter().default_units(val, None) self.assertEqual(result, (calendar, unit)) + def test_360_day_calendar_nd(self): + # Test the case where the input is an nd-array. + calendar = '360_day' + unit = 'days since 2000-01-01' + val = np.array([[CalendarDateTime(netcdftime.datetime(2014, 8, 12), + calendar)], + [CalendarDateTime(netcdftime.datetime(2014, 8, 13), + calendar)]]) + result = NetCDFTimeConverter().default_units(val, None) + self.assertEqual(result, (calendar, unit)) + + def test_nonequal_calendars(self): + # Test that different supplied calendars causes an error. + calendar_1 = '360_day' + calendar_2 = '365_day' + unit = 'days since 2000-01-01' + val = [CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar_1), + CalendarDateTime(netcdftime.datetime(2014, 8, 13), calendar_2)] + with self.assertRaisesRegexp(ValueError, 'not all equal'): + NetCDFTimeConverter().default_units(val, None) + class Test_convert(unittest.TestCase): def test_numpy_array(self): @@ -38,6 +66,13 @@ def test_numpy_array(self): result = NetCDFTimeConverter().convert(val, None, None) np.testing.assert_array_equal(result, val) + def test_numpy_nd_array(self): + shape = (4, 2) + val = np.arange(8).reshape(shape) + result = NetCDFTimeConverter().convert(val, None, None) + np.testing.assert_array_equal(result, val) + self.assertEqual(result.shape, shape) + def test_numeric(self): val = 4 result = NetCDFTimeConverter().convert(val, None, None)