
Commit 192b6d6

Untrack op in eval mode
test=release/1.4
1 parent 4914da1 commit 192b6d6

3 files changed (+89 -35 lines)

python/paddle/fluid/dygraph/layers.py (+12)

@@ -48,6 +48,12 @@ def __init__(self, name_scope, dtype=core.VarDesc.VarType.FP32):
 
         self._helper = LayerObjectHelper(self._full_name)
 
+    def train(self):
+        framework._dygraph_tracer()._train_mode()
+
+    def eval(self):
+        framework._dygraph_tracer()._eval_mode()
+
     def full_name(self):
         """Full name for this layers.
 
@@ -254,6 +260,12 @@ class PyLayer(core.PyLayer):
     def __init__(self):
         super(PyLayer, self).__init__()
 
+    def train(self):
+        framework._dygraph_tracer()._train_mode()
+
+    def eval(self):
+        framework._dygraph_tracer()._eval_mode()
+
     @classmethod
     def _do_forward(cls, inputs):
         return cls._to_tuple(cls.forward(inputs))

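The new train()/eval() hooks simply flip the tracer's mode flag, so switching a dygraph model between the two regimes is a one-line call on the layer. A minimal usage sketch, assuming the release/1.4 dygraph API (fluid.dygraph.guard, to_variable, FC); MyLayer and the input data are made up for illustration:

import numpy as np
import paddle.fluid as fluid

class MyLayer(fluid.dygraph.Layer):
    def __init__(self, name_scope):
        super(MyLayer, self).__init__(name_scope)
        self._fc = fluid.dygraph.FC(self.full_name(), 10)

    def forward(self, inputs):
        return self._fc(inputs)

with fluid.dygraph.guard():
    model = MyLayer("my_layer")
    x = fluid.dygraph.to_variable(np.ones([4, 8], dtype="float32"))

    model.train()   # tracer records and keeps every op for backward
    out = model(x)

    model.eval()    # tracer still runs ops but no longer retains them
    pred = model(x)
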
python/paddle/fluid/dygraph/tracer.py (+59 -8)

@@ -24,7 +24,9 @@
 
 
 def release_op(op):
-    del framework._dygraph_tracer()._ops[op._trace_id]
+    del framework._dygraph_tracer()._ops[op._trace_id].inputs
+    del framework._dygraph_tracer()._ops[op._trace_id].outputs
+    del framework._dygraph_tracer()._ops[op._trace_id].backward_refs
 
 
 class Tracer(core.Tracer):
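
The reworked release_op is the memory-saving half of the change: rather than dropping the whole _ops entry, it deletes the attributes that pin the Python-side input/output variables, so those become collectible while the op record stays addressable by trace id. A toy sketch of the mechanism (stand-in classes, CPython reference counting assumed):

import weakref

class FakeOp(object):
    pass

class FakeVar(object):
    pass

op = FakeOp()
op.inputs = {"X": [FakeVar()]}
probe = weakref.ref(op.inputs["X"][0])

del op.inputs            # drop the attribute, not the op record itself
assert probe() is None   # the variable was freed; op is still alive
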
@@ -38,6 +40,7 @@ def __init__(self, block):
         self._ops = defaultdict()
         self._vars = defaultdict()
         self._trace_id = 0
+        self._train_mode = True
 
     def trace_var(self, name, var):
         self._vars[name] = var
@@ -46,15 +49,57 @@ def all_parameters(self):
         return list((item for name, item in six.iteritems(self._vars)
                      if isinstance(item, framework.Parameter)))
 
-    def trace_op(self, op, stop_gradient=False):
+    def trace_op(self, op, inputs, outputs, stop_gradient=False):
+        # TODO(minqiyang): remove this line after we take apart all
+        # backward grads and forward variables
+        if self._train_mode:
+            op.inputs = inputs
+            inps = defaultdict(list)
+            for k, vars in six.iteritems(inputs):
+                if isinstance(vars, framework.Variable):
+                    inps[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        inps[k].append(var._ivar)
+
+            op.outputs = outputs
+            outs = defaultdict(list)
+            for k, vars in six.iteritems(outputs):
+                if isinstance(vars, framework.Variable):
+                    outs[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        outs[k].append(var._ivar)
+        else:
+            inps = defaultdict(list)
+            for k, vars in six.iteritems(inputs):
+                if isinstance(vars, framework.Variable):
+                    op.previous_ops.append(vars.op)
+                    inps[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        op.previous_ops.append(var.op)
+                        inps[k].append(var._ivar)
+
+            op.outputs = outputs
+            outs = defaultdict(list)
+            for k, vars in six.iteritems(outputs):
+                if isinstance(vars, framework.Variable):
+                    vars.op = op
+                    outs[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        var.op = op
+                        outs[k].append(var._ivar)
+
         # record op's trace id
         op.iop._trace_id = self._trace_id
 
-        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
+        backward_refs = self.trace(op.iop, inps, outs, op.attrs,
                                    framework._current_expected_place(),
                                    stop_gradient)
 
-        if not stop_gradient:
+        if not stop_gradient and self._train_mode:
             self._trace_id += 1
             self._ops[op.iop._trace_id] = op
 

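Both branches of the new trace_op repeat one normalization step: a dict mapping slot names to either a single Variable or a list/tuple of Variables is flattened into a defaultdict(list) of the underlying _ivar handles before being handed to self.trace. Distilled into a standalone helper (a sketch, not code from the commit; the Variable check is injected so the snippet stays self-contained):

from collections import defaultdict

import six

def flatten_ivars(var_map, is_variable):
    # var_map: {slot_name: Variable or [Variable, ...]}
    flat = defaultdict(list)
    for k, vars in six.iteritems(var_map):
        if is_variable(vars):
            flat[k].append(vars._ivar)
        elif isinstance(vars, (list, tuple)):
            for var in vars:
                flat[k].append(var._ivar)
    return flat
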
@@ -65,10 +110,16 @@ def trace_op(self, op, stop_gradient=False):
             # TODO(minqiyang): remove all inputs and outputs after separate
             # var and grad
             op.backward_refs = defaultdict(list)
-            for k, v in six.iteritems(op.inputs):
+            for k, v in six.iteritems(inputs):
                 if k in backward_refs:
-                    op.backward_refs[k] = op.inputs[k]
+                    op.backward_refs[k] = inputs[k]
 
-            for k, v in six.iteritems(op.outputs):
+            for k, v in six.iteritems(outputs):
                 if k in backward_refs:
-                    op.backward_refs[k] = op.outputs[k]
+                    op.backward_refs[k] = outputs[k]
+
+    def _train_mode(self):
+        self._train_mode = True
+
+    def _eval_mode(self):
+        self._train_mode = False

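The flag only matters at two points above: ops always execute through self.trace, but only in train mode are they appended to self._ops and kept alive for the backward pass. A toy restatement of the switch (names are illustrative, not the commit's):

class ToyOp(object):
    def run(self):
        pass

class ToyTracer(object):
    def __init__(self):
        self._ops = {}
        self._trace_id = 0
        self._training = True

    def trace_op(self, op, stop_gradient=False):
        op.run()  # the op executes in both modes
        if not stop_gradient and self._training:
            # train mode only: retain the op for the backward pass
            self._trace_id += 1
            self._ops[self._trace_id] = op
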
python/paddle/fluid/framework.py (+18 -27)

@@ -407,6 +407,7 @@ def __init__(self,
                                               if persistable else False)
             if persistable:
                 _dygraph_tracer().trace_var(name, self)
+            self.op = None
         else:
             self.error_clip = error_clip
 
@@ -935,26 +936,9 @@ def __init__(self,
                 raise ValueError(
                     "`type` to initialized an Operator can not be None.")
             self.iop = core.OpBase(type)
+            self.previous_ops = []
 
-            # TODO(minqiyang): remove these lines after we take apart all
-            # backward grads and forward variables
-            self.inputs = defaultdict(list)
-            if inputs is not None:
-                for k, v in six.iteritems(inputs):
-                    if isinstance(v, Variable):
-                        self.inputs[k].append(v._ivar)
-                    elif isinstance(v, list) or isinstance(v, tuple):
-                        self.inputs[k].extend([var._ivar for var in v])
-
-            self.outputs = defaultdict(list)
-            if outputs is not None:
-                for k, v in six.iteritems(outputs):
-                    if isinstance(v, Variable):
-                        self.outputs[k].append(v._ivar)
-                    elif isinstance(v, list) or isinstance(v, tuple):
-                        self.outputs[k].extend([var._ivar for var in v])
-
-            self.attrs = attrs if attrs else {}
+            self.attrs = attrs
         else:
             self.block = block
             self.desc = desc
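
Together with the new self.op = None on Variable, previous_ops is what eval mode uses instead of the tracer's _ops table: each traced output remembers its producing op, and each op links back to the producers of its inputs. A minimal sketch of the resulting chain (stand-in classes mirroring the eval-mode branch of trace_op):

class Var(object):
    def __init__(self):
        self.op = None          # producing op, filled in when traced

class Op(object):
    def __init__(self):
        self.previous_ops = []  # ops that produced this op's inputs

def link(op, inputs, outputs):
    for var in inputs:
        op.previous_ops.append(var.op)
    for var in outputs:
        var.op = op

a, b, c = Var(), Var(), Var()
op1, op2 = Op(), Op()
link(op1, [a], [b])
link(op2, [b], [c])
assert c.op is op2 and op2.previous_ops == [op1]
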
@@ -1643,15 +1627,18 @@ def append_op(self, *args, **kwargs):
                 block=self,
                 desc=None,
                 type=kwargs.get("type", None),
-                inputs=kwargs.get("inputs", None),
-                outputs=kwargs.get("outputs", None),
-                attrs=kwargs.get("attrs", None))
+                inputs=None,
+                outputs=None,
+                attrs=kwargs.get("attrs", {}))
 
             # record ops in tracer rather than blocks
             #
             # TODO(minqiyang): add op stop_gradient support in static mode too.
             # currently, we only support stop_gradient in dygraph mode.
-            _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
+            _dygraph_tracer().trace_op(op,
+                                       kwargs.get("inputs", {}),
+                                       kwargs.get("outputs", {}),
+                                       kwargs.get("stop_gradient", False))
         else:
             op_desc = self.desc.append_op()
             op = Operator(
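
On the caller side, Operator no longer receives inputs/outputs at all; the raw kwargs dicts go straight to the tracer, which flattens or links them according to the mode. A self-contained restatement with stub classes (StubOperator and StubTracer are stand-ins invented for this sketch, not Paddle APIs):

class StubOperator(object):
    # after this change the dygraph Operator keeps only type and attrs
    def __init__(self, type=None, inputs=None, outputs=None, attrs=None):
        assert inputs is None and outputs is None
        self.type, self.attrs = type, attrs

class StubTracer(object):
    def __init__(self):
        self.calls = []

    def trace_op(self, op, inputs, outputs, stop_gradient=False):
        self.calls.append((op.type, inputs, outputs, stop_gradient))

def append_op(tracer, **kwargs):
    # mirrors the dygraph branch of Block.append_op after this commit
    op = StubOperator(type=kwargs.get("type", None),
                      inputs=None,
                      outputs=None,
                      attrs=kwargs.get("attrs", {}))
    tracer.trace_op(op,
                    kwargs.get("inputs", {}),
                    kwargs.get("outputs", {}),
                    kwargs.get("stop_gradient", False))
    return op

tracer = StubTracer()
append_op(tracer, type="relu", inputs={"X": "x"}, outputs={"Out": "y"})
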
@@ -1715,10 +1702,14 @@ def _prepend_op(self, *args, **kwargs):
                 self,
                 None,
                 type=kwargs.get("type", None),
-                inputs=kwargs.get("inputs", None),
-                outputs=kwargs.get("outputs", None),
-                attrs=kwargs.get("attrs", None))
-            _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
+                inputs=None,
+                outputs=None,
+                attrs=kwargs.get("attrs", {}))
+
+            _dygraph_tracer().trace_op(op,
+                                       kwargs.get("inputs", {}),
+                                       kwargs.get("outputs", {}),
+                                       kwargs.get("stop_gradient", False))
         else:
             op_desc = self.desc._prepend_op()
             op = Operator(
