24
24
25
25
26
26
def release_op(op):
    """Release the backward-related references held for a finished op.

    Looks up the op recorded under ``op._trace_id`` in the active dygraph
    tracer and drops its ``inputs``, ``outputs`` and ``backward_refs``
    attributes so the forward Variables they retain become collectable.

    NOTE(review): the ``_ops`` table entry itself is kept; calling this
    twice for the same op would raise AttributeError — confirm callers
    invoke it at most once per op.
    """
    tracked_op = framework._dygraph_tracer()._ops[op._trace_id]
    del tracked_op.inputs
    del tracked_op.outputs
    del tracked_op.backward_refs
28
30
29
31
30
32
class Tracer (core .Tracer ):
@@ -38,6 +40,7 @@ def __init__(self, block):
38
40
self ._ops = defaultdict ()
39
41
self ._vars = defaultdict ()
40
42
self ._trace_id = 0
43
+ self ._train_mode = True
41
44
42
45
def trace_var (self , name , var ):
43
46
self ._vars [name ] = var
@@ -46,15 +49,57 @@ def all_parameters(self):
46
49
return list ((item for name , item in six .iteritems (self ._vars )
47
50
if isinstance (item , framework .Parameter )))
48
51
49
    def trace_op(self, op, inputs, outputs, stop_gradient=False):
        """Trace ``op`` through the underlying C++ tracer.

        Converts the Python ``Variable`` containers in *inputs*/*outputs*
        into ``_ivar`` handle lists, calls ``self.trace`` with them, and —
        in train mode, when gradients are required — records the op under a
        fresh trace id together with the inputs/outputs needed by backward.

        Args:
            op: Python op wrapper; must expose ``iop``, ``attrs`` and (in
                eval mode) ``previous_ops``.
            inputs: dict mapping slot name -> Variable or list/tuple of
                Variables.
            outputs: dict with the same structure as *inputs*.
            stop_gradient: if True, the op is not registered for backward.
        """
        # TODO(minqiyang): remove this line after we take apart all
        # backward grads and forward variables
        if self._train_mode:
            # Train mode: keep strong references to the Python Variables so
            # backward can reach them later (released via release_op).
            op.inputs = inputs
            inps = defaultdict(list)
            for k, vars in six.iteritems(inputs):
                if isinstance(vars, framework.Variable):
                    inps[k].append(vars._ivar)
                elif isinstance(vars, list) or isinstance(vars, tuple):
                    for var in vars:
                        inps[k].append(var._ivar)

            op.outputs = outputs
            outs = defaultdict(list)
            for k, vars in six.iteritems(outputs):
                if isinstance(vars, framework.Variable):
                    outs[k].append(vars._ivar)
                elif isinstance(vars, list) or isinstance(vars, tuple):
                    for var in vars:
                        outs[k].append(var._ivar)
        else:
            # Eval mode: instead of retaining input refs, link ops into a
            # graph — each input Variable's producing op is appended to
            # this op's previous_ops, and each output Variable is stamped
            # with this op as its producer.
            # NOTE(review): ``op.inputs`` is deliberately NOT set in this
            # branch (only in train mode) — confirm nothing downstream
            # reads it while evaluating.
            inps = defaultdict(list)
            for k, vars in six.iteritems(inputs):
                if isinstance(vars, framework.Variable):
                    op.previous_ops.append(vars.op)
                    inps[k].append(vars._ivar)
                elif isinstance(vars, list) or isinstance(vars, tuple):
                    for var in vars:
                        op.previous_ops.append(var.op)
                        inps[k].append(var._ivar)

            op.outputs = outputs
            outs = defaultdict(list)
            for k, vars in six.iteritems(outputs):
                if isinstance(vars, framework.Variable):
                    vars.op = op
                    outs[k].append(vars._ivar)
                elif isinstance(vars, list) or isinstance(vars, tuple):
                    for var in vars:
                        var.op = op
                        outs[k].append(var._ivar)

        # record op's trace id
        op.iop._trace_id = self._trace_id

        # The C++ tracer returns the set of slot names that backward will
        # actually need (presumably — verify against the core.Tracer API).
        backward_refs = self.trace(op.iop, inps, outs, op.attrs,
                                   framework._current_expected_place(),
                                   stop_gradient)

        # Trace ids advance (and ops are retained) only when training with
        # gradients enabled.
        if not stop_gradient and self._train_mode:
            self._trace_id += 1
            self._ops[op.iop._trace_id] = op

            # NOTE(review): the diff this was reconstructed from elides
            # four lines (new-file lines 106-109) at this point; the exact
            # statements and the nesting of the block below must be
            # confirmed against the upstream file.
            # TODO(minqiyang): remove all inputs and outputs after separate
            # var and grad
            op.backward_refs = defaultdict(list)
            for k, v in six.iteritems(inputs):
                if k in backward_refs:
                    op.backward_refs[k] = inputs[k]

            for k, v in six.iteritems(outputs):
                if k in backward_refs:
                    op.backward_refs[k] = outputs[k]
120
+
121
    def _train_mode(self):
        # Switch the tracer to training mode: trace_op will then retain
        # ops and advance trace ids for backward.
        # NOTE(review): __init__ sets the instance attribute
        # ``self._train_mode = True``, which shadows this method on every
        # instance — ``tracer._train_mode()`` would raise TypeError
        # ('bool' object is not callable). Method and attribute need
        # distinct names; confirm against upstream.
        self._train_mode = True
123
+
124
    def _eval_mode(self):
        # Switch the tracer to eval mode: trace_op will then build the
        # previous_ops/var.op graph instead of retaining backward refs.
        # NOTE(review): once ``self._train_mode`` has been assigned a bool
        # (e.g. in __init__), the sibling method ``_train_mode`` is
        # shadowed on the instance — see the name-collision note there.
        self._train_mode = False
0 commit comments