diff --git a/nc_time_axis/__init__.py b/nc_time_axis/__init__.py index 4ea7850..ccb8d87 100644 --- a/nc_time_axis/__init__.py +++ b/nc_time_axis/__init__.py @@ -209,14 +209,20 @@ def axisinfo(unit, axis): *unit* is a tzinfo instance or None. The *axis* argument is required but not used. """ - calendar, date_unit = unit + calendar, date_unit, date_type = unit majloc = NetCDFTimeDateLocator(4, calendar=calendar, date_unit=date_unit) majfmt = NetCDFTimeDateFormatter(majloc, calendar=calendar, time_units=date_unit) - datemin = CalendarDateTime(cftime.datetime(2000, 1, 1), calendar) - datemax = CalendarDateTime(cftime.datetime(2010, 1, 1), calendar) + if date_type is CalendarDateTime: + datemin = CalendarDateTime(cftime.datetime(2000, 1, 1), + calendar=calendar) + datemax = CalendarDateTime(cftime.datetime(2010, 1, 1), + calendar=calendar) + else: + datemin = date_type(2000, 1, 1) + datemax = date_type(2010, 1, 1) return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label='', default_limits=(datemin, datemax)) @@ -235,6 +241,7 @@ def default_units(cls, sample_point, axis): calendar = calendars[0] else: raise ValueError('Calendar units are not all equal.') + date_type = type(sample_point[0]) else: # Deal with a single `sample_point` value. if not hasattr(sample_point, 'calendar'): @@ -243,7 +250,8 @@ def default_units(cls, sample_point, axis): raise ValueError(msg) else: calendar = sample_point.calendar - return calendar, cls.standard_unit + date_type = type(sample_point) + return calendar, cls.standard_unit, date_type @classmethod def convert(cls, value, unit, axis): @@ -266,20 +274,27 @@ def convert(cls, value, unit, axis): return value first_value = value - if not isinstance(first_value, CalendarDateTime): + if not isinstance(first_value, (CalendarDateTime, cftime.datetime)): raise ValueError('The values must be numbers or instances of ' - '"nc_time_axis.CalendarDateTime".') + '"nc_time_axis.CalendarDateTime" or ' + '"cftime.datetime".') - if not isinstance(first_value.datetime, cftime.datetime): - raise ValueError('The datetime attribute of the CalendarDateTime ' - 'object must be of type `cftime.datetime`.') + if isinstance(first_value, CalendarDateTime): + if not isinstance(first_value.datetime, cftime.datetime): + raise ValueError('The datetime attribute of the ' + 'CalendarDateTime object must be of type ' + '`cftime.datetime`.') ut = cftime.utime(cls.standard_unit, calendar=first_value.calendar) - if isinstance(value, CalendarDateTime): + if isinstance(value, (CalendarDateTime, cftime.datetime)): value = [value] - result = ut.date2num([v.datetime for v in value]) + if isinstance(first_value, CalendarDateTime): + result = ut.date2num([v.datetime for v in value]) + else: + result = ut.date2num(value) + if shape is not None: result = result.reshape(shape) @@ -290,3 +305,10 @@ def convert(cls, value, unit, axis): # dictionary. if CalendarDateTime not in munits.registry: munits.registry[CalendarDateTime] = NetCDFTimeConverter() + +CFTIME_TYPES = [cftime.DatetimeNoLeap, cftime.DatetimeAllLeap, + cftime.DatetimeProlepticGregorian, cftime.DatetimeGregorian, + cftime.Datetime360Day, cftime.DatetimeJulian] +for date_type in CFTIME_TYPES: + if date_type not in munits.registry: + munits.registry[date_type] = NetCDFTimeConverter() diff --git a/nc_time_axis/tests/integration/test_plot.py b/nc_time_axis/tests/integration/test_plot.py index 4f95833..ba1d5b7 100644 --- a/nc_time_axis/tests/integration/test_plot.py +++ b/nc_time_axis/tests/integration/test_plot.py @@ -25,7 +25,7 @@ def tearDown(self): # in an odd state, so we make sure it's been disposed of. plt.close('all') - def test_360_day_calendar(self): + def test_360_day_calendar_CalendarDateTime(self): datetimes = [cftime.datetime(1986, month, 30) for month in range(1, 6)] cal_datetimes = [nc_time_axis.CalendarDateTime(dt, '360_day') @@ -34,6 +34,13 @@ def test_360_day_calendar(self): result_ydata = line1.get_ydata() np.testing.assert_array_equal(result_ydata, cal_datetimes) + def test_360_day_calendar_raw_dates(self): + datetimes = [cftime.Datetime360Day(1986, month, 30) + for month in range(1, 6)] + line1, = plt.plot(datetimes) + result_ydata = line1.get_ydata() + np.testing.assert_array_equal(result_ydata, datetimes) + if __name__ == "__main__": unittest.main() diff --git a/nc_time_axis/tests/unit/test_NetCDFTimeConverter.py b/nc_time_axis/tests/unit/test_NetCDFTimeConverter.py index da6634c..5d5edaf 100644 --- a/nc_time_axis/tests/unit/test_NetCDFTimeConverter.py +++ b/nc_time_axis/tests/unit/test_NetCDFTimeConverter.py @@ -15,7 +15,7 @@ class Test_axisinfo(unittest.TestCase): def test_axis_default_limits(self): cal = '360_day' - unit = (cal, 'days since 2000-02-25 00:00:00') + unit = (cal, 'days since 2000-02-25 00:00:00', CalendarDateTime) result = NetCDFTimeConverter().axisinfo(unit, None) expected_dt = [cftime.datetime(2000, 1, 1), cftime.datetime(2010, 1, 1)] @@ -25,21 +25,21 @@ def test_axis_default_limits(self): class Test_default_units(unittest.TestCase): - def test_360_day_calendar_point(self): + def test_360_day_calendar_point_CalendarDateTime(self): calendar = '360_day' unit = 'days since 2000-01-01' val = CalendarDateTime(cftime.datetime(2014, 8, 12), calendar) result = NetCDFTimeConverter().default_units(val, None) - self.assertEqual(result, (calendar, unit)) + self.assertEqual(result, (calendar, unit, CalendarDateTime)) - def test_360_day_calendar_list(self): + def test_360_day_calendar_list_CalendarDateTime(self): calendar = '360_day' unit = 'days since 2000-01-01' val = [CalendarDateTime(cftime.datetime(2014, 8, 12), calendar)] result = NetCDFTimeConverter().default_units(val, None) - self.assertEqual(result, (calendar, unit)) + self.assertEqual(result, (calendar, unit, CalendarDateTime)) - def test_360_day_calendar_nd(self): + def test_360_day_calendar_nd_CalendarDateTime(self): # Test the case where the input is an nd-array. calendar = '360_day' unit = 'days since 2000-01-01' @@ -48,7 +48,30 @@ def test_360_day_calendar_nd(self): [CalendarDateTime(cftime.datetime(2014, 8, 13), calendar)]]) result = NetCDFTimeConverter().default_units(val, None) - self.assertEqual(result, (calendar, unit)) + self.assertEqual(result, (calendar, unit, CalendarDateTime)) + + def test_360_day_calendar_point_raw_date(self): + calendar = '360_day' + unit = 'days since 2000-01-01' + val = cftime.Datetime360Day(2014, 8, 12) + result = NetCDFTimeConverter().default_units(val, None) + self.assertEqual(result, (calendar, unit, cftime.Datetime360Day)) + + def test_360_day_calendar_list_raw_date(self): + calendar = '360_day' + unit = 'days since 2000-01-01' + val = [cftime.Datetime360Day(2014, 8, 12)] + result = NetCDFTimeConverter().default_units(val, None) + self.assertEqual(result, (calendar, unit, cftime.Datetime360Day)) + + def test_360_day_calendar_nd_raw_date(self): + # Test the case where the input is an nd-array. + calendar = '360_day' + unit = 'days since 2000-01-01' + val = np.array([[cftime.Datetime360Day(2014, 8, 12)], + [cftime.Datetime360Day(2014, 8, 13)]]) + result = NetCDFTimeConverter().default_units(val, None) + self.assertEqual(result, (calendar, unit, cftime.Datetime360Day)) def test_nonequal_calendars(self): # Test that different supplied calendars causes an error. @@ -84,17 +107,27 @@ def test_numeric_iterable(self): result = NetCDFTimeConverter().convert(val, None, None) np.testing.assert_array_equal(result, val) - def test_cftime(self): + def test_cftime_CalendarDateTime(self): val = CalendarDateTime(cftime.datetime(2014, 8, 12), '365_day') result = NetCDFTimeConverter().convert(val, None, None) np.testing.assert_array_equal(result, 5333.) - def test_cftime_np_array(self): + def test_cftime_raw_date(self): + val = cftime.DatetimeNoLeap(2014, 8, 12) + result = NetCDFTimeConverter().convert(val, None, None) + np.testing.assert_array_equal(result, 5333.) + + def test_cftime_np_array_CalendarDateTime(self): val = np.array([CalendarDateTime(cftime.datetime(2012, 6, 4), '360_day')], dtype=np.object) result = NetCDFTimeConverter().convert(val, None, None) self.assertEqual(result, np.array([4473.])) + def test_cftime_np_array_raw_date(self): + val = np.array([cftime.Datetime360Day(2012, 6, 4)], dtype=np.object) + result = NetCDFTimeConverter().convert(val, None, None) + self.assertEqual(result, np.array([4473.])) + def test_non_cftime_datetime(self): val = CalendarDateTime(4, '360_day') msg = 'The datetime attribute of the CalendarDateTime object must ' \ @@ -103,7 +136,7 @@ def test_non_cftime_datetime(self): result = NetCDFTimeConverter().convert(val, None, None) def test_non_CalendarDateTime(self): - val = cftime.datetime(1988, 5, 6) + val = 'test' msg = 'The values must be numbers or instances of ' \ '"nc_time_axis.CalendarDateTime".' with assertRaisesRegex(self, ValueError, msg):