Skip to content

Commit 9e05d64

Browse files
committed
Merge pull request #27 from lensacom/ndim_size
Added ndim and size properties to ArrayRDD; resolves #13, #14
2 parents 91b5cf3 + 5723b6c commit 9e05d64

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

splearn/rdd.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ def __getitem__(self, key):
327327
else:
328328
return super(ArrayRDD, self).__getitem__(key)
329329

330+
@property
331+
def ndim(self):
332+
return self._rdd.first().ndim
333+
330334
@property
331335
def shape(self):
332336
"""Returns the shape of the data."""
@@ -335,6 +339,12 @@ def shape(self):
335339
shape = self._rdd.map(lambda x: x.shape[0]).sum()
336340
return (shape,) + first[1:]
337341

342+
@property
343+
def size(self):
344+
"""Returns the total number of elements in the data.
345+
"""
346+
return np.prod(self.shape)
347+
338348
def toarray(self):
339349
"""Returns the data as numpy.array from each partition."""
340350
return np.concatenate(self.collect())
@@ -564,12 +574,6 @@ def unblock(self):
564574
"""
565575
return self._rdd.flatMap(lambda cols: zip(*cols))
566576

567-
@property
568-
def shape(self):
569-
"""Returns the shape of the data.
570-
"""
571-
return (super(DictRDD, self).get(0).shape[0], self.columns)
572-
573577
def transform(self, fn, column=None):
574578
"""Execute a transformation on a column or columns. Returns the modified
575579
DictRDD.

splearn/tests/test_rdd.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,16 +287,41 @@ def test_blocks_size(self):
287287
shapes = ArrayRDD(rdd, 66).map(lambda x: x.shape[0]).collect()
288288
assert_true(all(np.in1d(shapes, [66, 34])))
289289

290+
def test_ndim(self):
291+
data = np.arange(4000)
292+
shapes = [(4000),
293+
(1000, 4),
294+
(200, 10, 2),
295+
(100, 10, 2, 2)]
296+
for shape in shapes:
297+
reshaped = data.reshape(shape)
298+
rdd = self.sc.parallelize(reshaped)
299+
assert_equal(ArrayRDD(rdd).ndim, reshaped.ndim)
300+
290301
def test_shape(self):
291302
data = np.arange(4000)
292303
shapes = [(1000, 4),
293304
(200, 20),
294305
(100, 40),
295306
(2000, 2)]
296307
for shape in shapes:
297-
rdd = self.sc.parallelize(data.reshape(shape))
308+
reshaped = data.reshape(shape)
309+
rdd = self.sc.parallelize(reshaped)
298310
assert_equal(ArrayRDD(rdd).shape, shape)
299311

312+
def test_size(self):
313+
data = np.arange(4000)
314+
shapes = [(1000, 4),
315+
(200, 20),
316+
(100, 40),
317+
(2000, 2)]
318+
for shape in shapes:
319+
reshaped = data.reshape(shape)
320+
rdd = self.sc.parallelize(reshaped)
321+
size = ArrayRDD(rdd).map(lambda x: x.size).sum()
322+
assert_equal(size, reshaped.size)
323+
assert_equal(ArrayRDD(rdd).size, reshaped.size)
324+
300325
def test_unblocking_rdd(self):
301326
data = np.arange(400)
302327
rdd = self.sc.parallelize(data, 4)

0 commit comments

Comments
 (0)