@@ -3,6 +3,7 @@
import cupy as cp
import numpy as np

+ # TODO: Add requires_grad condition to args

class Tensor:
    def __init__(self, data: Any, args=None, op=None, requires_grad: bool = True, dtype=None, device: str = "cpu"):
@@ -14,9 +15,9 @@ def __init__(self, data: Any, args=None, op=None, requires_grad: bool=True, dtype=None, device: str="cpu"):
            self.xp = cp

        if isinstance(data, Tensor):
-             self.data: Union[np.ndarray, cp.ndarray] = self.xp.array(data.data, dtype=dtype)
+             self.data: Union[np.ndarray, cp.ndarray] = self.xp.array(data.data, dtype=dtype if dtype else np.float32)
        else:
-             self.data = self.xp.array(data, dtype=dtype)
+             self.data = self.xp.array(data, dtype=dtype if dtype else np.float32)

        self.grad = None
        self.op = op
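
With this change, a Tensor constructed without an explicit dtype now defaults to np.float32 instead of whatever NumPy infers. A minimal sketch of the new behavior (the `from tensor import Tensor` path is an assumption; point it at wherever Tensor lives in this repo):

    import numpy as np
    from tensor import Tensor  # assumed module path

    t = Tensor([1, 2, 3])                      # no dtype given -> np.float32
    assert t.data.dtype == np.float32

    t64 = Tensor([1, 2, 3], dtype=np.float64)  # an explicit dtype is still honored
    assert t64.data.dtype == np.float64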
@@ -355,6 +356,145 @@ def flip(self, axis: Any) -> 'Tensor':
            device=self.device,
        )

+     def where(self, condition: Union[Any, 'Tensor'], t: Union[Any, 'Tensor']) -> 'Tensor':
+         condition = self.tensor(condition)
+         t = self.tensor(t)
+
+         requires_grad = self.requires_grad or t.requires_grad
+         args = [self, condition, t] if requires_grad else None
+
+         return Tensor(
+             self.xp.where(condition.data, self.data, t.data),
+             args,
+             "where",
+             requires_grad=requires_grad,
+             device=self.device,
+         )
+
+     def equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         t = self.tensor(t)
+
+         return Tensor(
+             self.xp.equal(self.data, t.data),
+             None,
+             "equal",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def not_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         t = self.tensor(t)
+
+         return Tensor(
+             self.xp.not_equal(self.data, t.data),
+             None,
+             "not_equal",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def greater(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         t = self.tensor(t)
+
+         return Tensor(
+             self.xp.greater(self.data, t.data),
+             None,
+             "greater",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def greater_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         t = self.tensor(t)
+
+         return Tensor(
+             self.xp.greater_equal(self.data, t.data),
+             None,
+             "greater_equal",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def less(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         t = self.tensor(t)
+
+         return Tensor(
+             self.xp.less(self.data, t.data),
+             None,
+             "less",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def less_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         t = self.tensor(t)
+
+         return Tensor(
+             self.xp.less_equal(self.data, t.data),
+             None,
+             "less_equal",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def logical_and(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         t = self.tensor(t)
+
+         return Tensor(
+             self.xp.logical_and(self.data, t.data),
+             None,
+             "logical_and",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def logical_or(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         t = self.tensor(t)
+
+         return Tensor(
+             self.xp.logical_or(self.data, t.data),
+             None,
+             "logical_or",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def logical_not(self) -> 'Tensor':
+         return Tensor(
+             self.xp.logical_not(self.data),
+             None,
+             "logical_not",
+             requires_grad=False,
+             device=self.device,
+         )
+
+     def __eq__(self, t: Union[Any, 'Tensor']) -> 'Tensor':  # type: ignore[override]
+         return self.equal(t)
+
+     def __ne__(self, t: Union[Any, 'Tensor']) -> 'Tensor':  # type: ignore[override]
+         return self.not_equal(t)
+
+     def __gt__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         return self.greater(t)
+
+     def __ge__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         return self.greater_equal(t)
+
+     def __lt__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         return self.less(t)
+
+     def __le__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         return self.less_equal(t)
+
+     def __and__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         return self.logical_and(t)
+
+     def __or__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
+         return self.logical_or(t)
+
+     def __invert__(self) -> 'Tensor':
+         return self.logical_not()
+
    def __neg__(self) -> 'Tensor':
        return Tensor(
-             self.data,
+             -self.data,
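
These overloads make elementwise comparisons read like plain NumPy; every result is a boolean Tensor with requires_grad=False, so comparisons never enter the autograd graph. A short usage sketch (same assumed import path as above):

    from tensor import Tensor  # assumed module path

    a = Tensor([1.0, 2.0, 3.0])
    b = Tensor([2.0, 2.0, 2.0])

    mask = (a > b) | (a == b)    # logical_or of two comparisons -> [False, True, True]
    flipped = ~mask              # logical_not                   -> [True, False, False]
    clipped = a.where(a < b, b)  # a where a < b, else b         -> [1.0, 2.0, 2.0]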
@@ -667,6 +807,11 @@ def backward(

        elif self.op == "flip":
            self.args[0].backward(self.xp.flip(grad, axis=self.args[1]))
+
+         elif self.op == "where":
+             # route grad to whichever input produced each element: args[0] where the condition held, args[2] elsewhere
+             self.args[0].backward(self.xp.where(self.args[1].data, grad, self.xp.zeros_like(grad)))
+             self.args[2].backward(self.xp.where(self.args[1].data, self.xp.zeros_like(grad), grad))

        elif self.op == "neg":
            self.args[0].backward(-grad)
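
Because where-style VJPs are easy to get wrong, a finite-difference check is a cheap sanity test. A plain-NumPy sketch (deliberately independent of this Tensor class) that verifies the masking rule used above:

    import numpy as np

    rng = np.random.default_rng(0)
    x, y = rng.normal(size=4), rng.normal(size=4)
    cond = x > 0.0
    up = rng.normal(size=4)                      # upstream gradient

    # analytic VJPs, mirroring the "where" branch above
    gx = np.where(cond, up, 0.0)
    gy = np.where(cond, 0.0, up)

    f = lambda xv, yv: np.sum(up * np.where(cond, xv, yv))
    eps = 1e-6
    num_gx = np.array([(f(x + eps * e, y) - f(x - eps * e, y)) / (2 * eps) for e in np.eye(4)])
    num_gy = np.array([(f(x, y + eps * e) - f(x, y - eps * e)) / (2 * eps) for e in np.eye(4)])

    assert np.allclose(gx, num_gx) and np.allclose(gy, num_gy)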
@@ -680,27 +824,4 @@ def backward(

# BUGS:
# gradient of X - mean does not match pytorch; maybe NOT a BUG because of small-number manipulation (numerical stability issues)
- # softmax grads do not match pytorch; place: div; maybe NOT a BUG because of small-number manipulation (numerical stability issues)????
-
-
- # def repeat_to_match_shape(self, g, shape, dtype, axis, keepdims): same
- # https://github.com/HIPS/autograd/blob/master/autograd/numpy/numpy_vjps.py
- # """Returns the array g repeated along axis to fit vector space vs.
- # Also returns the number of repetitions of the array."""
- # if shape == ():
- #     return g, 1
- # axis = list(axis) if isinstance(axis, tuple) else axis
- # new_shape = self.xp.array(shape)
- # new_shape[axis] = 1
- # num_reps = self.xp.prod(self.xp.array(shape)[axis])
- # # Can't use broadcast_to because of numpy bug: https://github.com/numpy/numpy/issues/9165
- # # return aself.xp.broadcast_to(aself.xp.reshape(g, new_shape), shape), num_reps
- # return self.xp.reshape(g, new_shape) + self.xp.zeros(shape, dtype=dtype), num_reps
-
- # elif self.op == "mean":
- #     shape = self.args[0].data.shape
- #     axis = self.args[1]
- #     dtype = self.xp.result_type(self.args[0].data)
- #     g_repeated, num_reps = self.repeat_to_match_shape(grad, shape, dtype, axis, None)
- #     print(f"g_repeated {g_repeated}")
- #     self.args[0].backward(g_repeated / num_reps)
+ # softmax grads do not match pytorch; place: div; maybe NOT a BUG because of small-number manipulation (numerical stability issues)????
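
Both remaining notes blame small-number (float32) round-off rather than a wrong derivative. One way to test that theory is to run the same op in float32 and float64 and compare the gap against the size of the PyTorch mismatch. A plain-NumPy sketch:

    import numpy as np

    def softmax(z):
        e = np.exp(z - z.max())
        return e / e.sum()

    z32 = np.random.default_rng(1).normal(size=64).astype(np.float32)
    gap = np.abs(softmax(z32) - softmax(z32.astype(np.float64))).max()
    print(gap)  # roughly 1e-8; mismatches of this size are float32 noise, not a wrong derivative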