From 1e7cac945368abd16ae0b636570a7fb8eaebb786 Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Fri, 25 Apr 2025 19:05:54 +0800
Subject: [PATCH 01/11] Update auto_cast.py

---
 python/paddle/amp/auto_cast.py | 78 ++++++++++++++++++++++++++++++----
 1 file changed, 70 insertions(+), 8 deletions(-)

diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index c204775d896046..8ad00b999a9a40 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -14,6 +14,7 @@
 from __future__ import annotations

 import copy
+import os
 import warnings
 from typing import (
     TYPE_CHECKING,
@@ -655,6 +656,30 @@ def amp_guard(
         and not amp_global_state().already_register_final_backward_hook
     ):

+        def _dtensor_from_local(
+            local_tensor, mesh, placements, local_tensor_shape=None
+        ):
+            global_dims = list(local_tensor.shape)
+            if local_tensor_shape is not None:
+                global_dims = local_tensor_shape
+            for idx, placement in enumerate(placements):
+                if placement.is_shard():
+                    shard_dim = placement.get_dim()
+                    local_dim_size = global_dims[shard_dim]
+                    global_dims[shard_dim] = (
+                        local_dim_size * mesh.shape[idx]
+                    )
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+
+            return paddle.Tensor(
+                local_tensor,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+
         def master_grad_hook():
             # NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
             # classify the params of model into different classes according to their process_mesh.
@@ -674,17 +699,54 @@ def master_grad_hook():
                             param.process_mesh
                         ].append(param)
                 amp_global_state().already_classify_params_meshes = True
-
-            if len(amp_global_state().mesh2params):
-                for _, params in amp_global_state().mesh2params.items():
-                    core.eager.set_master_grads(params)
-            else:
-                core.eager.set_master_grads(
-                    amp_global_state().model_parameters
-                )
+            enable_inplace_master_grad = (
+                os.getenv("FLAGS_enable_inplace_master_grad") == '1'
+            )
+            if not enable_inplace_master_grad:
+                if len(amp_global_state().mesh2params):
+                    for _, params in amp_global_state().mesh2params.items():
+                        core.eager.set_master_grads(params)
+                else:
+                    core.eager.set_master_grads(
+                        amp_global_state().model_parameters
+                    )

             amp_global_state().already_register_final_backward_hook = False

+        def _update_main_grad_hook(param):
+            @paddle.autograd.no_grad()
+            def param_hook(tmp_grad):
+                if tmp_grad is not None and tmp_grad._is_initialized():
+                    if param.main_grad is None:
+                        tmp = core.eager.Tensor(
+                            value=tmp_grad._local_value()
+                            .cast(paddle.float32)
+                            .value(),
+                            place=tmp_grad.place,
+                            name="main_grad@" + param.name,
+                        )
+                        param.main_grad = _dtensor_from_local(
+                            tmp,
+                            tmp_grad.process_mesh,
+                            tmp_grad.placements,
+                        )
+                    else:
+                        param.main_grad._local_value().add_(
+                            tmp_grad._local_value()
+                        )
+                    tmp_grad._clear_data()
+
+            return param_hook
+
+        enable_inplace_master_grad = (
+            os.getenv("FLAGS_enable_inplace_master_grad") == '1'
+        )
+        if enable_inplace_master_grad:
+            for param in amp_global_state().model_parameters:
+                if not hasattr(param, "main_grad"):
+                    param.main_grad = None
+                param._register_grad_hook(_update_main_grad_hook(param))
+
         core.eager._add_backward_final_hook(master_grad_hook)
         amp_global_state().already_register_final_backward_hook = True

From b279aea39b68fb7ef30cc35b7f2fc330e5df6f00 Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Fri, 25 Apr 2025 19:06:45 +0800
Subject: [PATCH 02/11] Update optimizer.py

---
 python/paddle/optimizer/optimizer.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index d2dce1b588d906..3138f59ae3a854 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -2000,9 +2000,16 @@ def step(self) -> None:
             for param in self._param_groups:
                 if param.stop_gradient:
                     continue
-                if param._grad_ivar() is not None:
-                    grad_var = param._grad_ivar()
-                    params_grads.append((param, grad_var))
+                enable_inplace_master_grad = (
+                    os.getenv("FLAGS_enable_inplace_master_grad") == '1'
+                )
+                if not enable_inplace_master_grad:
+                    if param._grad_ivar() is not None:
+                        grad_var = param._grad_ivar()
+                        params_grads.append((param, grad_var))
+                else:
+                    if param.main_grad is not None:
+                        params_grads.append((param, param.main_grad))

             self._apply_optimize(
                 loss=None,

From 4fc5812152fa093444cb94fad6eb1e44ef83fc05 Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Fri, 25 Apr 2025 19:08:44 +0800
Subject: [PATCH 03/11] Update api.py

---
 python/paddle/distributed/auto_parallel/api.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 37051e1fe29978..eea78021d4c59c 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -1253,6 +1253,12 @@ def _create_accumulators(self, block, parameters):

     def _finish_update(self, block, parameters_and_grads):
         self._inner_opt._finish_update(block, parameters_and_grads)
+        enable_inplace_master_grad = (
+            os.getenv("FLAGS_enable_inplace_master_grad") == '1'
+        )
+        if enable_inplace_master_grad:
+            for param, _ in parameters_and_grads:
+                param.main_grad._local_value().zero_()
         if isinstance(parameters_and_grads, list):
             for p, _ in parameters_and_grads:
                 self._reset_placements(p)

From a43626a76517610d6390bebca741a4a5e5723ffc Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Tue, 29 Apr 2025 15:46:51 +0800
Subject: [PATCH 04/11] Update optimizer.py

---
 python/paddle/optimizer/optimizer.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 3138f59ae3a854..8673c99a995761 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -2000,16 +2000,16 @@ def step(self) -> None:
             for param in self._param_groups:
                 if param.stop_gradient:
                     continue
-                enable_inplace_master_grad = (
-                    os.getenv("FLAGS_enable_inplace_master_grad") == '1'
-                )
-                if not enable_inplace_master_grad:
+                if os.getenv("FLAGS_enable_inplace_master_grad") == '1':
+                    if (
+                        hasattr(param, "main_grad")
+                        and param.main_grad is not None
+                    ):
+                        params_grads.append((param, param.main_grad))
+                else:
                     if param._grad_ivar() is not None:
                         grad_var = param._grad_ivar()
                         params_grads.append((param, grad_var))
-                else:
-                    if param.main_grad is not None:
-                        params_grads.append((param, param.main_grad))

             self._apply_optimize(
                 loss=None,

From bfc03c968abfdce1873bd0f4d7740b7da905c16e Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Tue, 29 Apr 2025 15:49:30 +0800
Subject: [PATCH 05/11] Update api.py

---
 python/paddle/distributed/auto_parallel/api.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index eea78021d4c59c..82ba00bc58ad80 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -1128,6 +1128,10 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
             for param in self._inner_opt._parameter_list:
                 self._shard_fn._shard_parameter(param)

+        self.enable_inplace_master_grad = (
+            os.getenv("FLAGS_enable_inplace_master_grad") == '1'
+        )
+
     def _set_and_check_sharding_prop_from_param(self):
         global_mesh = fleet.auto.get_mesh()
         if global_mesh:
@@ -1253,10 +1257,7 @@ def _create_accumulators(self, block, parameters):

     def _finish_update(self, block, parameters_and_grads):
         self._inner_opt._finish_update(block, parameters_and_grads)
-        enable_inplace_master_grad = (
-            os.getenv("FLAGS_enable_inplace_master_grad") == '1'
-        )
-        if enable_inplace_master_grad:
+        if self.enable_inplace_master_grad:
             for param, _ in parameters_and_grads:
                 param.main_grad._local_value().zero_()
         if isinstance(parameters_and_grads, list):

From b64e29b3431f3be08f041c9fe1b41eca4cdfce24 Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Tue, 29 Apr 2025 15:49:50 +0800
Subject: [PATCH 06/11] Update auto_cast.py

---
 python/paddle/amp/auto_cast.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 8ad00b999a9a40..d2b9b13b5304f4 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -699,10 +699,7 @@ def master_grad_hook():
                             param.process_mesh
                         ].append(param)
                 amp_global_state().already_classify_params_meshes = True
-            enable_inplace_master_grad = (
-                os.getenv("FLAGS_enable_inplace_master_grad") == '1'
-            )
-            if not enable_inplace_master_grad:
+            if not os.getenv("FLAGS_enable_inplace_master_grad") == '1':
                 if len(amp_global_state().mesh2params):
                     for _, params in amp_global_state().mesh2params.items():
                         core.eager.set_master_grads(params)
@@ -738,10 +735,7 @@ def param_hook(tmp_grad):

             return param_hook

-        enable_inplace_master_grad = (
-            os.getenv("FLAGS_enable_inplace_master_grad") == '1'
-        )
-        if enable_inplace_master_grad:
+        if os.getenv("FLAGS_enable_inplace_master_grad") == '1':
             for param in amp_global_state().model_parameters:
                 if not hasattr(param, "main_grad"):
                     param.main_grad = None

From 4e656d0e6911a27222c91d5da502c49d97a56f7b Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Tue, 29 Apr 2025 15:57:30 +0800
Subject: [PATCH 07/11] Update api.py

From 0459942e60c802bd00d0c3d0df1c13165e469f6d Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Fri, 9 May 2025 16:06:50 +0800
Subject: [PATCH 08/11] Update semi_auto_parallel_sharding_stage_1.py

---
 .../semi_auto_parallel_sharding_stage_1.py | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py b/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
index 061da8a7978f03..e2b66999efcaee 100644
--- a/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
+++ b/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
@@ -173,6 +173,36 @@ def test_sharding_stage_1_overlap_to_static(self):
         for batch_id, (image, label) in enumerate(dist_loader()):
             loss = dist_model(image, label)

+    def test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad(self):
+        def run_sharding_test(enable_inplace_master_grad):
+            os.environ['FLAGS_enable_inplace_master_grad'] = (
+                '1' if enable_inplace_master_grad else '0'
+            )
+            paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
+            paddle.seed(self._seed)
+            model = paddle.nn.Linear(10, 10)
+            batch = paddle.rand(shape=[10, 10])
+            batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
+            opt = paddle.optimizer.AdamW(parameters=model.parameters())
+            opt = dist.shard_optimizer(
+                opt, dist.ShardingStage1(sharding_mesh_dim="dp")
+            )
+            model, opt = paddle.amp.decorate(
+                model, optimizers=opt, level='O2', master_grad=True
+            )
+            for _ in range(5):
+                with paddle.amp.auto_cast(level='O2'):
+                    loss = model(batch)
+                loss.backward()
+                opt.step()
+                opt.clear_grad()
+            return loss.numpy()
+
+        dist.init_parallel_env()
+        loss_disable = run_sharding_test(enable_inplace_master_grad=False)
+        loss_enable = run_sharding_test(enable_inplace_master_grad=True)
+        self.check_tensor_eq(loss_disable, loss_enable)
+
     def run_test_case(self):
         if self._backend == "cpu":
             paddle.set_device("cpu")
@@ -188,6 +218,7 @@ def run_test_case(self):
         self.test_sharding_stage_1_to_static()
         self.test_pure_sharding_multi_mesh_stage_1()
         self.test_sharding_stage_1_overlap_to_static()
+        self.test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad()


 if __name__ == '__main__':

From 388364db07da1af63bae58c225926b4ef397c75b Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Mon, 12 May 2025 16:53:30 +0800
Subject: [PATCH 09/11] Update auto_cast.py

---
 python/paddle/amp/auto_cast.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index d2b9b13b5304f4..538393891b1277 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -656,18 +656,12 @@ def amp_guard(
         and not amp_global_state().already_register_final_backward_hook
     ):

-        def _dtensor_from_local(
-            local_tensor, mesh, placements, local_tensor_shape=None
-        ):
+        def _dtensor_from_local(local_tensor, mesh, placements):
             global_dims = list(local_tensor.shape)
-            if local_tensor_shape is not None:
-                global_dims = local_tensor_shape
             for idx, placement in enumerate(placements):
                 if placement.is_shard():
-                    shard_dim = placement.get_dim()
-                    local_dim_size = global_dims[shard_dim]
-                    global_dims[shard_dim] = (
-                        local_dim_size * mesh.shape[idx]
+                    global_dims[placement.get_dim()] = (
+                        global_dims[placement.get_dim()] * mesh.shape[idx]
                     )
             place = paddle.framework._current_expected_place()
             place = paddle.framework._get_paddle_place(place)

From b78d9d0d26356a5c4bee804d30ebd2e667519485 Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Mon, 12 May 2025 19:20:50 +0800
Subject: [PATCH 10/11] Update auto_cast.py

From ef7b30bbf2a55f06598e54e6ad4c05634210e03f Mon Sep 17 00:00:00 2001
From: Xing-lil
Date: Tue, 13 May 2025 11:39:58 +0800
Subject: [PATCH 11/11] del Flags_enable_inplace_master_grad

---
 python/paddle/amp/auto_cast.py                 |  4 ++--
 python/paddle/distributed/auto_parallel/api.py |  6 +++---
 python/paddle/optimizer/optimizer.py           |  2 +-
 .../semi_auto_parallel_sharding_stage_1.py     | 14 +++++++-------
 4 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 538393891b1277..c16c1f75e79183 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -693,7 +693,7 @@ def master_grad_hook():
                             param.process_mesh
                         ].append(param)
                 amp_global_state().already_classify_params_meshes = True
-            if not os.getenv("FLAGS_enable_inplace_master_grad") == '1':
+            if not os.getenv("FLAGS_enable_tensor_fusion") == '1':
                 if len(amp_global_state().mesh2params):
                     for _, params in amp_global_state().mesh2params.items():
                         core.eager.set_master_grads(params)
@@ -729,7 +729,7 @@ def param_hook(tmp_grad):

             return param_hook

-        if os.getenv("FLAGS_enable_inplace_master_grad") == '1':
+        if os.getenv("FLAGS_enable_tensor_fusion") == '1':
             for param in amp_global_state().model_parameters:
                 if not hasattr(param, "main_grad"):
                     param.main_grad = None
diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 82ba00bc58ad80..65dcbf28875876 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -1128,8 +1128,8 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
             for param in self._inner_opt._parameter_list:
                 self._shard_fn._shard_parameter(param)

-        self.enable_inplace_master_grad = (
-            os.getenv("FLAGS_enable_inplace_master_grad") == '1'
+        self.enable_tensor_fusion = (
+            os.getenv("FLAGS_enable_tensor_fusion") == '1'
         )

     def _set_and_check_sharding_prop_from_param(self):
@@ -1257,7 +1257,7 @@ def _create_accumulators(self, block, parameters):

     def _finish_update(self, block, parameters_and_grads):
         self._inner_opt._finish_update(block, parameters_and_grads)
-        if self.enable_inplace_master_grad:
+        if self.enable_tensor_fusion:
             for param, _ in parameters_and_grads:
                 param.main_grad._local_value().zero_()
         if isinstance(parameters_and_grads, list):
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 8673c99a995761..e180d9ba0b1aa8 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -2000,7 +2000,7 @@ def step(self) -> None:
             for param in self._param_groups:
                 if param.stop_gradient:
                     continue
-                if os.getenv("FLAGS_enable_inplace_master_grad") == '1':
+                if os.getenv("FLAGS_enable_tensor_fusion") == '1':
                     if (
                         hasattr(param, "main_grad")
                         and param.main_grad is not None
diff --git a/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py b/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
index e2b66999efcaee..cd2114df69dc82 100644
--- a/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
+++ b/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
@@ -173,10 +173,10 @@ def test_sharding_stage_1_overlap_to_static(self):
         for batch_id, (image, label) in enumerate(dist_loader()):
             loss = dist_model(image, label)

-    def test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad(self):
-        def run_sharding_test(enable_inplace_master_grad):
-            os.environ['FLAGS_enable_inplace_master_grad'] = (
-                '1' if enable_inplace_master_grad else '0'
+    def test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion(self):
+        def run_sharding_test(enable_tensor_fusion):
+            os.environ['FLAGS_enable_tensor_fusion'] = (
+                '1' if enable_tensor_fusion else '0'
             )
             paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
             paddle.seed(self._seed)
@@ -199,8 +199,8 @@ def run_sharding_test(enable_inplace_master_grad):
             return loss.numpy()

         dist.init_parallel_env()
-        loss_disable = run_sharding_test(enable_inplace_master_grad=False)
-        loss_enable = run_sharding_test(enable_inplace_master_grad=True)
+        loss_disable = run_sharding_test(enable_tensor_fusion=False)
+        loss_enable = run_sharding_test(enable_tensor_fusion=True)
         self.check_tensor_eq(loss_disable, loss_enable)

     def run_test_case(self):
@@ -218,7 +218,7 @@ def run_test_case(self):
         self.test_sharding_stage_1_to_static()
         self.test_pure_sharding_multi_mesh_stage_1()
         self.test_sharding_stage_1_overlap_to_static()
-        self.test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad()
+        self.test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion()


 if __name__ == '__main__':