From 8187ed2b5c17928c085f778eb6d57f08abe4022a Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Sun, 20 Oct 2024 23:34:40 +0800
Subject: [PATCH] add release for color sharding (#68826)

---
 .../dygraph_optimizer/dygraph_sharding_optimizer.py   | 8 +++++---
 .../distributed/fleet/utils/tensor_fusion_helper.py   | 3 +++
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index 8c57f7e67f6519..19efbce9979938 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -810,7 +810,8 @@ def clear_grad_func(p):
 
         if self.sd_release_grads and not self.pp_overlap:
             for comm_buffer in self._comm_buffer_list:
-                comm_buffer._clear_grad_storage()
+                if comm_buffer.need_reduce_scale_sync():
+                    comm_buffer._clear_grad_storage()
 
     def filter_parameters(self, parameter_list, hcg):
         parameter_list = [
@@ -834,8 +835,9 @@ def reduce_gradients(self, parameter_list, hcg):
         with framework.no_grad():
             for comm_buffer in self._comm_buffer_list:
                 if self.sd_release_grads and comm_buffer.grad_storage is None:
-                    for param in comm_buffer.params:
-                        comm_buffer._copy_grad_to_buffer(param)
+                    if comm_buffer.need_reduce_scale_sync():
+                        for param in comm_buffer.params:
+                            comm_buffer._copy_grad_to_buffer(param)
 
                 if g_sharding_v2_check_zero_padding:
                     self._check_padding_zero()
diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
index 935eb283259362..eb429e9006d50f 100644
--- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -553,6 +553,9 @@ def _copy_grad_to_buffer(self, param):
         )
 
         grad_var = param.main_grad if self.use_main_grad else param.grad
+        assert (
+            grad_var is not None
+        ), f"The current parameter[{param.name}] has no gradient, its stop_gradient is {param.stop_gradient}"
         grad_var.stop_gradient = True
         grad_var.flatten_()
 
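
The hunks above all apply the same guard: grad-storage release and repopulation now only happen for communication buffers whose need_reduce_scale_sync() returns True, so buffers that do not take part in the current reduce/scatter step keep (or never rebuild) their fused grad storage. Below is a minimal, self-contained sketch of that guard pattern; the class and helper names are hypothetical stand-ins for illustration, not Paddle's actual comm-buffer API.

# Minimal sketch of the guard pattern in this patch (hypothetical names, not
# Paddle's real comm-buffer API): grad storage is only released for buffers
# that actually need a reduce/scatter sync.


class DemoCommBuffer:
    """Stand-in for a fused-gradient communication buffer."""

    def __init__(self, name, needs_sync):
        self.name = name
        self._needs_sync = needs_sync
        self.grad_storage = object()  # pretend a fused grad buffer exists

    def need_reduce_scale_sync(self):
        # In the real optimizer this reflects whether the buffer takes part
        # in the current reduce/scatter step (e.g. its sharding color group
        # is active); here it is just a stored flag.
        return self._needs_sync

    def _clear_grad_storage(self):
        self.grad_storage = None


def release_grad_storage(comm_buffer_list):
    # Mirrors the patched clear-grad path: skip buffers that will not be
    # synced, so their storage is neither freed here nor rebuilt later.
    for comm_buffer in comm_buffer_list:
        if comm_buffer.need_reduce_scale_sync():
            comm_buffer._clear_grad_storage()


if __name__ == "__main__":
    buffers = [DemoCommBuffer("buf0", True), DemoCommBuffer("buf1", False)]
    release_grad_storage(buffers)
    print([b.grad_storage is None for b in buffers])  # [True, False]

The same check wraps the _copy_grad_to_buffer loop in reduce_gradients, and the new assert in tensor_fusion_helper.py turns a later None dereference into an explicit error naming the parameter and its stop_gradient state.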