Skip to content

Commit c5f6f8d

Browse files
committed
add where + other tensor ops
1 parent 536b7b6 commit c5f6f8d

File tree

3 files changed

+181
-27
lines changed

3 files changed

+181
-27
lines changed

neunet/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,33 @@ def swapaxes(x, axis1, axis2):
193193

194194
def flip(x, axis):
    """Reverse the order of elements of ``x`` along ``axis``.

    Thin functional wrapper that delegates to the tensor's own
    ``flip`` method.
    """
    flipped = x.flip(axis=axis)
    return flipped
196+
197+
def where(condition, x, y):
    """Select elements from ``x`` where ``condition`` holds, else from ``y``.

    Delegates to ``x.where(condition, y)``.
    """
    selected = x.where(condition, y)
    return selected
199+
200+
def equal(x, y):
    """Element-wise ``x == y`` comparison; delegates to ``x.equal``."""
    result = x.equal(y)
    return result
202+
203+
def not_equal(x, y):
    """Element-wise ``x != y`` comparison; delegates to ``x.not_equal``."""
    result = x.not_equal(y)
    return result
205+
206+
def greater(x, y):
    """Element-wise ``x > y`` comparison; delegates to ``x.greater``."""
    result = x.greater(y)
    return result
208+
209+
def greater_equal(x, y):
    """Element-wise ``x >= y`` comparison; delegates to ``x.greater_equal``."""
    result = x.greater_equal(y)
    return result
211+
212+
def less(x, y):
    """Element-wise ``x < y`` comparison; delegates to ``x.less``."""
    result = x.less(y)
    return result
214+
215+
def less_equal(x, y):
    """Element-wise ``x <= y`` comparison; delegates to ``x.less_equal``."""
    result = x.less_equal(y)
    return result
217+
218+
def logical_and(x, y):
    """Element-wise logical AND; delegates to ``x.logical_and``."""
    result = x.logical_and(y)
    return result
220+
221+
def logical_or(x, y):
    """Element-wise logical OR; delegates to ``x.logical_or``."""
    result = x.logical_or(y)
    return result
223+
224+
def logical_not(x):
    """Element-wise logical negation; delegates to ``x.logical_not``."""
    result = x.logical_not()
    return result

neunet/autograd.py

Lines changed: 147 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import cupy as cp
44
import numpy as np
55

6+
# TODO: Add requires_grad condition to args
67

78
class Tensor:
89
def __init__(self, data: Any, args=None, op=None, requires_grad: bool=True, dtype = None, device: str="cpu"):
@@ -14,9 +15,9 @@ def __init__(self, data: Any, args=None, op=None, requires_grad: bool=True, dtyp
1415
self.xp = cp
1516

1617
if isinstance(data, Tensor):
17-
self.data: Union[np.ndarray, cp.ndarray] = self.xp.array(data.data, dtype=dtype)
18+
self.data: Union[np.ndarray, cp.ndarray] = self.xp.array(data.data, dtype=dtype if dtype else np.float32)
1819
else:
19-
self.data = self.xp.array(data, dtype=dtype)
20+
self.data = self.xp.array(data, dtype=dtype if dtype else np.float32)
2021

2122
self.grad = None
2223
self.op = op
@@ -355,6 +356,145 @@ def flip(self, axis: Any) -> 'Tensor':
355356
device=self.device,
356357
)
357358

359+
def where(self, condition: Union[Any, 'Tensor'], t: Union[Any, 'Tensor']) -> 'Tensor':
    """Select elements from ``self`` where ``condition`` is true, else from ``t``.

    The condition is treated as non-differentiable data; gradient tracking is
    enabled when either ``self`` or ``t`` requires grad.
    """
    condition = self.tensor(condition)
    t = self.tensor(t)

    requires_grad = self.requires_grad or t.requires_grad
    args = [self, condition, t] if requires_grad else None

    # BUG FIX: was hard-coded np.where; use self.xp so the selection runs on
    # the same backend (numpy or cupy) as self.data — np.where on CUDA
    # (cupy) tensors would fail or silently mix devices.
    return Tensor(
        self.xp.where(condition.data, self.data, t.data),
        args,
        "where",
        requires_grad=requires_grad,
        device=self.device,
    )
373+
374+
def equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """Element-wise ``self == t``.

    Returns a detached tensor (``requires_grad=False``): comparisons are
    not differentiable.
    """
    other = self.tensor(t)
    data = self.xp.equal(self.data, other.data)
    return Tensor(data, None, "equal", requires_grad=False, device=self.device)
384+
385+
def not_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """Element-wise ``self != t``; result is detached (``requires_grad=False``)."""
    other = self.tensor(t)
    data = self.xp.not_equal(self.data, other.data)
    return Tensor(data, None, "not_equal", requires_grad=False, device=self.device)
395+
396+
def greater(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """Element-wise ``self > t``; result is detached (``requires_grad=False``)."""
    other = self.tensor(t)
    data = self.xp.greater(self.data, other.data)
    return Tensor(data, None, "greater", requires_grad=False, device=self.device)
406+
407+
def greater_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """Element-wise ``self >= t``; result is detached (``requires_grad=False``)."""
    other = self.tensor(t)
    data = self.xp.greater_equal(self.data, other.data)
    return Tensor(data, None, "greater_equal", requires_grad=False, device=self.device)
417+
418+
def less(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """Element-wise ``self < t``; result is detached (``requires_grad=False``)."""
    other = self.tensor(t)
    data = self.xp.less(self.data, other.data)
    return Tensor(data, None, "less", requires_grad=False, device=self.device)
428+
429+
def less_equal(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """Element-wise ``self <= t``; result is detached (``requires_grad=False``)."""
    other = self.tensor(t)
    data = self.xp.less_equal(self.data, other.data)
    return Tensor(data, None, "less_equal", requires_grad=False, device=self.device)
439+
440+
def logical_and(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """Element-wise logical AND with ``t``; result is detached."""
    other = self.tensor(t)
    data = self.xp.logical_and(self.data, other.data)
    return Tensor(data, None, "logical_and", requires_grad=False, device=self.device)
450+
451+
def logical_or(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """Element-wise logical OR with ``t``; result is detached."""
    other = self.tensor(t)
    data = self.xp.logical_or(self.data, other.data)
    return Tensor(data, None, "logical_or", requires_grad=False, device=self.device)
461+
462+
def logical_not(self) -> 'Tensor':
    """Element-wise logical negation of ``self``; result is detached."""
    inverted = self.xp.logical_not(self.data)
    return Tensor(inverted, None, "logical_not", requires_grad=False, device=self.device)
470+
471+
def __eq__(self, t: Union[Any, 'Tensor']) -> 'Tensor':  # type: ignore[override]
    """``self == t`` — element-wise, via :meth:`equal` (not object identity)."""
    result = self.equal(t)
    return result
473+
474+
def __ne__(self, t: Union[Any, 'Tensor']) -> 'Tensor':  # type: ignore[override]
    """``self != t`` — element-wise, via :meth:`not_equal`."""
    result = self.not_equal(t)
    return result
476+
477+
def __gt__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """``self > t`` — element-wise, via :meth:`greater`."""
    result = self.greater(t)
    return result
479+
480+
def __ge__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """``self >= t`` — element-wise, via :meth:`greater_equal`."""
    result = self.greater_equal(t)
    return result
482+
483+
def __lt__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """``self < t`` — element-wise, via :meth:`less`."""
    result = self.less(t)
    return result
485+
486+
def __le__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """``self <= t`` — element-wise, via :meth:`less_equal`."""
    result = self.less_equal(t)
    return result
488+
489+
def __and__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """``self & t`` — logical (not bitwise) AND, via :meth:`logical_and`."""
    result = self.logical_and(t)
    return result
491+
492+
def __or__(self, t: Union[Any, 'Tensor']) -> 'Tensor':
    """``self | t`` — logical (not bitwise) OR, via :meth:`logical_or`."""
    result = self.logical_or(t)
    return result
494+
495+
def __invert__(self) -> 'Tensor':
    """``~self`` — logical (not bitwise) negation, via :meth:`logical_not`."""
    result = self.logical_not()
    return result
497+
358498
def __neg__(self) -> 'Tensor':
359499
return Tensor(
360500
-self.data,
@@ -667,6 +807,10 @@ def backward(
667807

668808
elif self.op == "flip":
669809
self.args[0].backward(self.xp.flip(grad, axis=self.args[1]))
810+
811+
elif self.op == "where":
812+
self.args[0].backward(grad * self.xp.where(self.args[1].data, grad, self.xp.zeros_like(grad)))
813+
self.args[2].backward(grad * self.xp.where(self.args[1].data, self.xp.zeros_like(grad), grad))
670814

671815
elif self.op == "neg":
672816
self.args[0].backward(-grad)
@@ -680,27 +824,4 @@ def backward(
680824

681825
# BUGS:
682826
# grad X - mean not correct with pytorch; maybe NOT a BUG because of small-number manipulation (numerical stability issues)
683-
# softmax not equals grads with pytorch; place: div; maybe NOT BUG becase small numbers manipulation (Numerical stability issues)????
684-
685-
686-
# def repeat_to_match_shape(self, g, shape, dtype, axis, keepdims): same
687-
# https://github.com/HIPS/autograd/blob/master/autograd/numpy/numpy_vjps.py
688-
# """Returns the array g repeated along axis to fit vector space vs.
689-
# Also returns the number of repetitions of the array."""
690-
# if shape == ():
691-
# return g, 1
692-
# axis = list(axis) if isinstance(axis, tuple) else axis
693-
# new_shape = self.xp.array(shape)
694-
# new_shape[axis] = 1
695-
# num_reps = self.xp.prod(self.xp.array(shape)[axis])
696-
# # Can't use broadcast_to because of numpy bug: https://github.com/numpy/numpy/issues/9165
697-
# # return aself.xp.broadcast_to(aself.xp.reshape(g, new_shape), shape), num_reps
698-
# return self.xp.reshape(g, new_shape) + self.xp.zeros(shape, dtype=dtype), num_reps
699-
700-
# elif self.op == "mean":
701-
# shape = self.args[0].data.shape
702-
# axis = self.args[1]
703-
# dtype = self.xp.result_type(self.args[0].data)
704-
# g_repeated, num_reps = self.repeat_to_match_shape(grad, shape, dtype, axis, None)
705-
# print(f"g_repeated {g_repeated}")
706-
# self.args[0].backward(g_repeated / num_reps)
827+
# softmax grads do not match pytorch; place: div; maybe NOT a BUG because of small-number manipulation (numerical stability issues)????

neunet/nn/modules.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,17 @@ def backward(self, *args, **kwargs):
2121

2222
def parameters(self):
2323
params = []
24+
seen_ids = set()
2425
for _, item in self.__dict__.items():
2526
if isinstance(item, Tensor):
27+
item_id = id(item)
2628
if (
2729
item.requires_grad
2830
and item.__class__.__name__ == "Parameter"
29-
and item not in params
31+
and item_id not in seen_ids
3032
):
3133
params.append(item)
34+
seen_ids.add(item_id)
3235
if hasattr(item, "parameters"):
3336
params.extend(item.parameters())
3437

0 commit comments

Comments
 (0)