diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index c34ec8f45e15f3..b9967ca202c80c 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -20,6 +20,7 @@ from paddle import _legacy_C_ops from paddle.distributed.parallel import _split_tensors from paddle.fluid import core +from paddle.framework import base as imperative_base __all__ = [] @@ -165,6 +166,7 @@ def add_grad(self, param): if self._all_params_checked_in: self._fused_allreduce_grads() + @imperative_base.no_grad def _fused_allreduce_grads(self): assert self._all_params_checked_in flattened_vars = [] @@ -188,6 +190,7 @@ def _fused_allreduce_grads(self): ) ) + @imperative_base.no_grad def scale_and_split_grads(self): for task in self._tasks: task.wait()