We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 06304ad commit 279ac75Copy full SHA for 279ac75
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -20,6 +20,7 @@
20
from paddle import _legacy_C_ops
21
from paddle.distributed.parallel import _split_tensors
22
from paddle.fluid import core
23
+from paddle.framework import base as imperative_base
24
25
__all__ = []
26
@@ -165,6 +166,7 @@ def add_grad(self, param):
165
166
if self._all_params_checked_in:
167
self._fused_allreduce_grads()
168
169
+ @imperative_base.no_grad
170
def _fused_allreduce_grads(self):
171
assert self._all_params_checked_in
172
flattened_vars = []
@@ -188,6 +190,7 @@ def _fused_allreduce_grads(self):
188
190
)
189
191
192
193
194
def scale_and_split_grads(self):
195
for task in self._tasks:
196
task.wait()
0 commit comments