Skip to content

Commit d64de36

Browse files
authored
Merge pull request #21 from dkillick/nd_sample_point
Add convert support for nd arrays
2 parents 1021ea3 + b4f1e1c commit d64de36

File tree

2 files changed

+62
-14
lines changed

2 files changed

+62
-14
lines changed

nc_time_axis/__init__.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -226,17 +226,24 @@ def default_units(cls, sample_point, axis):
226226
Computes some units for the given data point.
227227
228228
"""
229-
try:
230-
# Try getting the first item. Otherwise we just use this item.
231-
sample_point = sample_point[0]
232-
except (TypeError, IndexError):
233-
pass
234-
235-
if not hasattr(sample_point, 'calendar'):
236-
msg = 'Expecting netcdftimes with an extra "calendar" attribute.'
237-
raise ValueError(msg)
238-
239-
return sample_point.calendar, cls.standard_unit
229+
if hasattr(sample_point, '__iter__'):
230+
# Deal with nD `sample_point` arrays.
231+
if isinstance(sample_point, np.ndarray):
232+
sample_point = sample_point.reshape(-1)
233+
calendars = np.array([point.calendar for point in sample_point])
234+
if np.all(calendars[0] == calendars):
235+
calendar = calendars[0]
236+
else:
237+
raise ValueError('Calendar units are not all equal.')
238+
else:
239+
# Deal with a single `sample_point` value.
240+
if not hasattr(sample_point, 'calendar'):
241+
msg = ('Expecting netcdftimes with an extra '
242+
'"calendar" attribute.')
243+
raise ValueError(msg)
244+
else:
245+
calendar = sample_point.calendar
246+
return calendar, cls.standard_unit
240247

241248
@classmethod
242249
def convert(cls, value, unit, axis):
@@ -245,11 +252,13 @@ def convert(cls, value, unit, axis):
245252
with :func:`netcdftime.utime().date2num`.
246253
247254
"""
255+
shape = None
248256
if isinstance(value, np.ndarray):
249257
# Don't do anything with numeric types.
250258
if value.dtype != np.object:
251259
return value
252-
260+
shape = value.shape
261+
value = value.reshape(-1)
253262
first_value = value[0]
254263
else:
255264
# Don't do anything with numeric types.
@@ -270,7 +279,11 @@ def convert(cls, value, unit, axis):
270279
if isinstance(value, CalendarDateTime):
271280
value = [value]
272281

273-
return ut.date2num([v.datetime for v in value])
282+
result = ut.date2num([v.datetime for v in value])
283+
if shape is not None:
284+
result = result.reshape(shape)
285+
286+
return result
274287

275288

276289
# Automatically register NetCDFTimeConverter with matplotlib.unit's converter

nc_time_axis/tests/unit/test_NetCDFTimeConverter.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,55 @@ def test_axis_default_limits(self):
2424

2525

2626
class Test_default_units(unittest.TestCase):
27-
def test_360_day_calendar(self):
27+
def test_360_day_calendar_point(self):
28+
calendar = '360_day'
29+
unit = 'days since 2000-01-01'
30+
val = CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar)
31+
result = NetCDFTimeConverter().default_units(val, None)
32+
self.assertEqual(result, (calendar, unit))
33+
34+
def test_360_day_calendar_list(self):
2835
calendar = '360_day'
2936
unit = 'days since 2000-01-01'
3037
val = [CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar)]
3138
result = NetCDFTimeConverter().default_units(val, None)
3239
self.assertEqual(result, (calendar, unit))
3340

41+
def test_360_day_calendar_nd(self):
42+
# Test the case where the input is an nd-array.
43+
calendar = '360_day'
44+
unit = 'days since 2000-01-01'
45+
val = np.array([[CalendarDateTime(netcdftime.datetime(2014, 8, 12),
46+
calendar)],
47+
[CalendarDateTime(netcdftime.datetime(2014, 8, 13),
48+
calendar)]])
49+
result = NetCDFTimeConverter().default_units(val, None)
50+
self.assertEqual(result, (calendar, unit))
51+
52+
def test_nonequal_calendars(self):
53+
# Test that different supplied calendars causes an error.
54+
calendar_1 = '360_day'
55+
calendar_2 = '365_day'
56+
unit = 'days since 2000-01-01'
57+
val = [CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar_1),
58+
CalendarDateTime(netcdftime.datetime(2014, 8, 13), calendar_2)]
59+
with self.assertRaisesRegexp(ValueError, 'not all equal'):
60+
NetCDFTimeConverter().default_units(val, None)
61+
3462

3563
class Test_convert(unittest.TestCase):
3664
def test_numpy_array(self):
3765
val = np.array([7])
3866
result = NetCDFTimeConverter().convert(val, None, None)
3967
np.testing.assert_array_equal(result, val)
4068

69+
def test_numpy_nd_array(self):
70+
shape = (4, 2)
71+
val = np.arange(8).reshape(shape)
72+
result = NetCDFTimeConverter().convert(val, None, None)
73+
np.testing.assert_array_equal(result, val)
74+
self.assertEqual(result.shape, shape)
75+
4176
def test_numeric(self):
4277
val = 4
4378
result = NetCDFTimeConverter().convert(val, None, None)

0 commit comments

Comments
 (0)