File tree Expand file tree Collapse file tree 2 files changed +36
-7
lines changed Expand file tree Collapse file tree 2 files changed +36
-7
lines changed Original file line number Diff line number Diff line change @@ -327,6 +327,10 @@ def __getitem__(self, key):
327
327
else :
328
328
return super (ArrayRDD , self ).__getitem__ (key )
329
329
330
+ @property
331
+ def ndim (self ):
332
+ return self ._rdd .first ().ndim
333
+
330
334
@property
331
335
def shape (self ):
332
336
"""Returns the shape of the data."""
@@ -335,6 +339,12 @@ def shape(self):
335
339
shape = self ._rdd .map (lambda x : x .shape [0 ]).sum ()
336
340
return (shape ,) + first [1 :]
337
341
342
+ @property
343
+ def size (self ):
344
+ """Returns the shape of the data.
345
+ """
346
+ return np .prod (self .shape )
347
+
338
348
def toarray (self ):
339
349
"""Returns the data as numpy.array from each partition."""
340
350
return np .concatenate (self .collect ())
@@ -564,12 +574,6 @@ def unblock(self):
564
574
"""
565
575
return self ._rdd .flatMap (lambda cols : zip (* cols ))
566
576
567
- @property
568
- def shape (self ):
569
- """Returns the shape of the data.
570
- """
571
- return (super (DictRDD , self ).get (0 ).shape [0 ], self .columns )
572
-
573
577
def transform (self , fn , column = None ):
574
578
"""Execute a transformation on a column or columns. Returns the modified
575
579
DictRDD.
Original file line number Diff line number Diff line change @@ -287,16 +287,41 @@ def test_blocks_size(self):
287
287
shapes = ArrayRDD (rdd , 66 ).map (lambda x : x .shape [0 ]).collect ()
288
288
assert_true (all (np .in1d (shapes , [66 , 34 ])))
289
289
290
+ def test_ndim (self ):
291
+ data = np .arange (4000 )
292
+ shapes = [(4000 ),
293
+ (1000 , 4 ),
294
+ (200 , 10 , 2 ),
295
+ (100 , 10 , 2 , 2 )]
296
+ for shape in shapes :
297
+ reshaped = data .reshape (shape )
298
+ rdd = self .sc .parallelize (reshaped )
299
+ assert_equal (ArrayRDD (rdd ).ndim , reshaped .ndim )
300
+
290
301
def test_shape (self ):
291
302
data = np .arange (4000 )
292
303
shapes = [(1000 , 4 ),
293
304
(200 , 20 ),
294
305
(100 , 40 ),
295
306
(2000 , 2 )]
296
307
for shape in shapes :
297
- rdd = self .sc .parallelize (data .reshape (shape ))
308
+ reshaped = data .reshape (shape )
309
+ rdd = self .sc .parallelize (reshaped )
298
310
assert_equal (ArrayRDD (rdd ).shape , shape )
299
311
312
+ def test_size (self ):
313
+ data = np .arange (4000 )
314
+ shapes = [(1000 , 4 ),
315
+ (200 , 20 ),
316
+ (100 , 40 ),
317
+ (2000 , 2 )]
318
+ for shape in shapes :
319
+ reshaped = data .reshape (shape )
320
+ rdd = self .sc .parallelize (reshaped )
321
+ size = ArrayRDD (rdd ).map (lambda x : x .size ).sum ()
322
+ assert_equal (size , reshaped .size )
323
+ assert_equal (ArrayRDD (rdd ).size , reshaped .size )
324
+
300
325
def test_unblocking_rdd (self ):
301
326
data = np .arange (400 )
302
327
rdd = self .sc .parallelize (data , 4 )
You can’t perform that action at this time.
0 commit comments