Skip to content

Commit 279ac75

Browse files
authored
fix split_tensor of dp_pp_comm_overlap (#54310)
1 parent 06304ad commit 279ac75

File tree

1 file changed

+3
-0
lines changed
  • python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py

1 file changed

+3
-0
lines changed

python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from paddle import _legacy_C_ops
2121
from paddle.distributed.parallel import _split_tensors
2222
from paddle.fluid import core
23+
from paddle.framework import base as imperative_base
2324

2425
__all__ = []
2526

@@ -165,6 +166,7 @@ def add_grad(self, param):
165166
if self._all_params_checked_in:
166167
self._fused_allreduce_grads()
167168

169+
@imperative_base.no_grad
168170
def _fused_allreduce_grads(self):
169171
assert self._all_params_checked_in
170172
flattened_vars = []
@@ -188,6 +190,7 @@ def _fused_allreduce_grads(self):
188190
)
189191
)
190192

193+
@imperative_base.no_grad
191194
def scale_and_split_grads(self):
192195
for task in self._tasks:
193196
task.wait()

0 commit comments

Comments (0)