Skip to content

Commit 531dd0d

Browse files
spencerkclarklbdreyer
authored andcommitted
Add support for plotting subclasses of cftime.datetime (#42)
1 parent 46aee1a commit 531dd0d

File tree

3 files changed

+84
-22
lines changed

3 files changed

+84
-22
lines changed

nc_time_axis/__init__.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,20 @@ def axisinfo(unit, axis):
209209
*unit* is a tzinfo instance or None.
210210
The *axis* argument is required but not used.
211211
"""
212-
calendar, date_unit = unit
212+
calendar, date_unit, date_type = unit
213213

214214
majloc = NetCDFTimeDateLocator(4, calendar=calendar,
215215
date_unit=date_unit)
216216
majfmt = NetCDFTimeDateFormatter(majloc, calendar=calendar,
217217
time_units=date_unit)
218-
datemin = CalendarDateTime(cftime.datetime(2000, 1, 1), calendar)
219-
datemax = CalendarDateTime(cftime.datetime(2010, 1, 1), calendar)
218+
if date_type is CalendarDateTime:
219+
datemin = CalendarDateTime(cftime.datetime(2000, 1, 1),
220+
calendar=calendar)
221+
datemax = CalendarDateTime(cftime.datetime(2010, 1, 1),
222+
calendar=calendar)
223+
else:
224+
datemin = date_type(2000, 1, 1)
225+
datemax = date_type(2010, 1, 1)
220226
return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label='',
221227
default_limits=(datemin, datemax))
222228

@@ -235,6 +241,7 @@ def default_units(cls, sample_point, axis):
235241
calendar = calendars[0]
236242
else:
237243
raise ValueError('Calendar units are not all equal.')
244+
date_type = type(sample_point[0])
238245
else:
239246
# Deal with a single `sample_point` value.
240247
if not hasattr(sample_point, 'calendar'):
@@ -243,7 +250,8 @@ def default_units(cls, sample_point, axis):
243250
raise ValueError(msg)
244251
else:
245252
calendar = sample_point.calendar
246-
return calendar, cls.standard_unit
253+
date_type = type(sample_point)
254+
return calendar, cls.standard_unit, date_type
247255

248256
@classmethod
249257
def convert(cls, value, unit, axis):
@@ -266,20 +274,27 @@ def convert(cls, value, unit, axis):
266274
return value
267275
first_value = value
268276

269-
if not isinstance(first_value, CalendarDateTime):
277+
if not isinstance(first_value, (CalendarDateTime, cftime.datetime)):
270278
raise ValueError('The values must be numbers or instances of '
271-
'"nc_time_axis.CalendarDateTime".')
279+
'"nc_time_axis.CalendarDateTime" or '
280+
'"cftime.datetime".')
272281

273-
if not isinstance(first_value.datetime, cftime.datetime):
274-
raise ValueError('The datetime attribute of the CalendarDateTime '
275-
'object must be of type `cftime.datetime`.')
282+
if isinstance(first_value, CalendarDateTime):
283+
if not isinstance(first_value.datetime, cftime.datetime):
284+
raise ValueError('The datetime attribute of the '
285+
'CalendarDateTime object must be of type '
286+
'`cftime.datetime`.')
276287

277288
ut = cftime.utime(cls.standard_unit, calendar=first_value.calendar)
278289

279-
if isinstance(value, CalendarDateTime):
290+
if isinstance(value, (CalendarDateTime, cftime.datetime)):
280291
value = [value]
281292

282-
result = ut.date2num([v.datetime for v in value])
293+
if isinstance(first_value, CalendarDateTime):
294+
result = ut.date2num([v.datetime for v in value])
295+
else:
296+
result = ut.date2num(value)
297+
283298
if shape is not None:
284299
result = result.reshape(shape)
285300

@@ -290,3 +305,10 @@ def convert(cls, value, unit, axis):
290305
# dictionary.
291306
if CalendarDateTime not in munits.registry:
292307
munits.registry[CalendarDateTime] = NetCDFTimeConverter()
308+
309+
CFTIME_TYPES = [cftime.DatetimeNoLeap, cftime.DatetimeAllLeap,
310+
cftime.DatetimeProlepticGregorian, cftime.DatetimeGregorian,
311+
cftime.Datetime360Day, cftime.DatetimeJulian]
312+
for date_type in CFTIME_TYPES:
313+
if date_type not in munits.registry:
314+
munits.registry[date_type] = NetCDFTimeConverter()

nc_time_axis/tests/integration/test_plot.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def tearDown(self):
2525
# in an odd state, so we make sure it's been disposed of.
2626
plt.close('all')
2727

28-
def test_360_day_calendar(self):
28+
def test_360_day_calendar_CalendarDateTime(self):
2929
datetimes = [cftime.datetime(1986, month, 30)
3030
for month in range(1, 6)]
3131
cal_datetimes = [nc_time_axis.CalendarDateTime(dt, '360_day')
@@ -34,6 +34,13 @@ def test_360_day_calendar(self):
3434
result_ydata = line1.get_ydata()
3535
np.testing.assert_array_equal(result_ydata, cal_datetimes)
3636

37+
def test_360_day_calendar_raw_dates(self):
38+
datetimes = [cftime.Datetime360Day(1986, month, 30)
39+
for month in range(1, 6)]
40+
line1, = plt.plot(datetimes)
41+
result_ydata = line1.get_ydata()
42+
np.testing.assert_array_equal(result_ydata, datetimes)
43+
3744

3845
if __name__ == "__main__":
3946
unittest.main()

nc_time_axis/tests/unit/test_NetCDFTimeConverter.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class Test_axisinfo(unittest.TestCase):
1616
def test_axis_default_limits(self):
1717
cal = '360_day'
18-
unit = (cal, 'days since 2000-02-25 00:00:00')
18+
unit = (cal, 'days since 2000-02-25 00:00:00', CalendarDateTime)
1919
result = NetCDFTimeConverter().axisinfo(unit, None)
2020
expected_dt = [cftime.datetime(2000, 1, 1),
2121
cftime.datetime(2010, 1, 1)]
@@ -25,21 +25,21 @@ def test_axis_default_limits(self):
2525

2626

2727
class Test_default_units(unittest.TestCase):
28-
def test_360_day_calendar_point(self):
28+
def test_360_day_calendar_point_CalendarDateTime(self):
2929
calendar = '360_day'
3030
unit = 'days since 2000-01-01'
3131
val = CalendarDateTime(cftime.datetime(2014, 8, 12), calendar)
3232
result = NetCDFTimeConverter().default_units(val, None)
33-
self.assertEqual(result, (calendar, unit))
33+
self.assertEqual(result, (calendar, unit, CalendarDateTime))
3434

35-
def test_360_day_calendar_list(self):
35+
def test_360_day_calendar_list_CalendarDateTime(self):
3636
calendar = '360_day'
3737
unit = 'days since 2000-01-01'
3838
val = [CalendarDateTime(cftime.datetime(2014, 8, 12), calendar)]
3939
result = NetCDFTimeConverter().default_units(val, None)
40-
self.assertEqual(result, (calendar, unit))
40+
self.assertEqual(result, (calendar, unit, CalendarDateTime))
4141

42-
def test_360_day_calendar_nd(self):
42+
def test_360_day_calendar_nd_CalendarDateTime(self):
4343
# Test the case where the input is an nd-array.
4444
calendar = '360_day'
4545
unit = 'days since 2000-01-01'
@@ -48,7 +48,30 @@ def test_360_day_calendar_nd(self):
4848
[CalendarDateTime(cftime.datetime(2014, 8, 13),
4949
calendar)]])
5050
result = NetCDFTimeConverter().default_units(val, None)
51-
self.assertEqual(result, (calendar, unit))
51+
self.assertEqual(result, (calendar, unit, CalendarDateTime))
52+
53+
def test_360_day_calendar_point_raw_date(self):
54+
calendar = '360_day'
55+
unit = 'days since 2000-01-01'
56+
val = cftime.Datetime360Day(2014, 8, 12)
57+
result = NetCDFTimeConverter().default_units(val, None)
58+
self.assertEqual(result, (calendar, unit, cftime.Datetime360Day))
59+
60+
def test_360_day_calendar_list_raw_date(self):
61+
calendar = '360_day'
62+
unit = 'days since 2000-01-01'
63+
val = [cftime.Datetime360Day(2014, 8, 12)]
64+
result = NetCDFTimeConverter().default_units(val, None)
65+
self.assertEqual(result, (calendar, unit, cftime.Datetime360Day))
66+
67+
def test_360_day_calendar_nd_raw_date(self):
68+
# Test the case where the input is an nd-array.
69+
calendar = '360_day'
70+
unit = 'days since 2000-01-01'
71+
val = np.array([[cftime.Datetime360Day(2014, 8, 12)],
72+
[cftime.Datetime360Day(2014, 8, 13)]])
73+
result = NetCDFTimeConverter().default_units(val, None)
74+
self.assertEqual(result, (calendar, unit, cftime.Datetime360Day))
5275

5376
def test_nonequal_calendars(self):
5477
# Test that different supplied calendars causes an error.
@@ -84,17 +107,27 @@ def test_numeric_iterable(self):
84107
result = NetCDFTimeConverter().convert(val, None, None)
85108
np.testing.assert_array_equal(result, val)
86109

87-
def test_cftime(self):
110+
def test_cftime_CalendarDateTime(self):
88111
val = CalendarDateTime(cftime.datetime(2014, 8, 12), '365_day')
89112
result = NetCDFTimeConverter().convert(val, None, None)
90113
np.testing.assert_array_equal(result, 5333.)
91114

92-
def test_cftime_np_array(self):
115+
def test_cftime_raw_date(self):
116+
val = cftime.DatetimeNoLeap(2014, 8, 12)
117+
result = NetCDFTimeConverter().convert(val, None, None)
118+
np.testing.assert_array_equal(result, 5333.)
119+
120+
def test_cftime_np_array_CalendarDateTime(self):
93121
val = np.array([CalendarDateTime(cftime.datetime(2012, 6, 4),
94122
'360_day')], dtype=np.object)
95123
result = NetCDFTimeConverter().convert(val, None, None)
96124
self.assertEqual(result, np.array([4473.]))
97125

126+
def test_cftime_np_array_raw_date(self):
127+
val = np.array([cftime.Datetime360Day(2012, 6, 4)], dtype=np.object)
128+
result = NetCDFTimeConverter().convert(val, None, None)
129+
self.assertEqual(result, np.array([4473.]))
130+
98131
def test_non_cftime_datetime(self):
99132
val = CalendarDateTime(4, '360_day')
100133
msg = 'The datetime attribute of the CalendarDateTime object must ' \
@@ -103,7 +136,7 @@ def test_non_cftime_datetime(self):
103136
result = NetCDFTimeConverter().convert(val, None, None)
104137

105138
def test_non_CalendarDateTime(self):
106-
val = cftime.datetime(1988, 5, 6)
139+
val = 'test'
107140
msg = 'The values must be numbers or instances of ' \
108141
'"nc_time_axis.CalendarDateTime".'
109142
with assertRaisesRegex(self, ValueError, msg):

0 commit comments

Comments
 (0)