diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index c204775d89604..c16c1f75e7918 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -14,6 +14,7 @@
 from __future__ import annotations
 
 import copy
+import os
 import warnings
 from typing import (
     TYPE_CHECKING,
@@ -655,6 +656,24 @@ def amp_guard(
         and not amp_global_state().already_register_final_backward_hook
     ):
 
+        def _dtensor_from_local(local_tensor, mesh, placements):
+            global_dims = list(local_tensor.shape)
+            for idx, placement in enumerate(placements):
+                if placement.is_shard():
+                    global_dims[placement.get_dim()] = (
+                        global_dims[placement.get_dim()] * mesh.shape[idx]
+                    )
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+
+            return paddle.Tensor(
+                local_tensor,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+
         def master_grad_hook():
             # NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
             # classify the params of model into different classes according to their process_mesh.
@@ -674,17 +693,48 @@ def master_grad_hook():
                                 param.process_mesh
                             ].append(param)
                 amp_global_state().already_classify_params_meshes = True
-            if len(amp_global_state().mesh2params):
-                for _, params in amp_global_state().mesh2params.items():
-                    core.eager.set_master_grads(params)
-            else:
-                core.eager.set_master_grads(
-                    amp_global_state().model_parameters
-                )
+            if not os.getenv("FLAGS_enable_tensor_fusion") == '1':
+                if len(amp_global_state().mesh2params):
+                    for _, params in amp_global_state().mesh2params.items():
+                        core.eager.set_master_grads(params)
+                else:
+                    core.eager.set_master_grads(
+                        amp_global_state().model_parameters
+                    )
 
             amp_global_state().already_register_final_backward_hook = False
 
+        def _update_main_grad_hook(param):
+            @paddle.autograd.no_grad()
+            def param_hook(tmp_grad):
+                if tmp_grad is not None and tmp_grad._is_initialized():
+                    if param.main_grad is None:
+                        tmp = core.eager.Tensor(
+                            value=tmp_grad._local_value()
+                            .cast(paddle.float32)
+                            .value(),
+                            place=tmp_grad.place,
+                            name="main_grad@" + param.name,
+                        )
+                        param.main_grad = _dtensor_from_local(
+                            tmp,
+                            tmp_grad.process_mesh,
+                            tmp_grad.placements,
+                        )
+                    else:
+                        param.main_grad._local_value().add_(
+                            tmp_grad._local_value()
+                        )
+                    tmp_grad._clear_data()
+
+            return param_hook
+
+        if os.getenv("FLAGS_enable_tensor_fusion") == '1':
+            for param in amp_global_state().model_parameters:
+                if not hasattr(param, "main_grad"):
+                    param.main_grad = None
+                param._register_grad_hook(_update_main_grad_hook(param))
+
         core.eager._add_backward_final_hook(master_grad_hook)
         amp_global_state().already_register_final_backward_hook = True
diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 37051e1fe2997..65dcbf2887587 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -1128,6 +1128,10 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
             for param in self._inner_opt._parameter_list:
                 self._shard_fn._shard_parameter(param)
 
+        self.enable_tensor_fusion = (
+            os.getenv("FLAGS_enable_tensor_fusion") == '1'
+        )
+
     def _set_and_check_sharding_prop_from_param(self):
         global_mesh = fleet.auto.get_mesh()
         if global_mesh:
@@ -1253,6 +1257,9 @@ def _create_accumulators(self, block, parameters):
 
     def _finish_update(self, block, parameters_and_grads):
         self._inner_opt._finish_update(block, parameters_and_grads)
+        if self.enable_tensor_fusion:
+            for param, _ in parameters_and_grads:
+                param.main_grad._local_value().zero_()
         if isinstance(parameters_and_grads, list):
             for p, _ in parameters_and_grads:
                 self._reset_placements(p)
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index d2dce1b588d90..e180d9ba0b1aa 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -2000,9 +2000,16 @@ def step(self) -> None:
             for param in self._param_groups:
                 if param.stop_gradient:
                     continue
-                if param._grad_ivar() is not None:
-                    grad_var = param._grad_ivar()
-                    params_grads.append((param, grad_var))
+                if os.getenv("FLAGS_enable_tensor_fusion") == '1':
+                    if (
+                        hasattr(param, "main_grad")
+                        and param.main_grad is not None
+                    ):
+                        params_grads.append((param, param.main_grad))
+                else:
+                    if param._grad_ivar() is not None:
+                        grad_var = param._grad_ivar()
+                        params_grads.append((param, grad_var))
 
             self._apply_optimize(
                 loss=None,
diff --git a/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py b/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
index 061da8a7978f0..cd2114df69dc8 100644
--- a/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
+++ b/test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
@@ -173,6 +173,36 @@ def test_sharding_stage_1_overlap_to_static(self):
         for batch_id, (image, label) in enumerate(dist_loader()):
             loss = dist_model(image, label)
 
+    def test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion(self):
+        def run_sharding_test(enable_tensor_fusion):
+            os.environ['FLAGS_enable_tensor_fusion'] = (
+                '1' if enable_tensor_fusion else '0'
+            )
+            paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
+            paddle.seed(self._seed)
+            model = paddle.nn.Linear(10, 10)
+            batch = paddle.rand(shape=[10, 10])
+            batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
+            opt = paddle.optimizer.AdamW(parameters=model.parameters())
+            opt = dist.shard_optimizer(
+                opt, dist.ShardingStage1(sharding_mesh_dim="dp")
+            )
+            model, opt = paddle.amp.decorate(
+                model, optimizers=opt, level='O2', master_grad=True
+            )
+            for _ in range(5):
+                with paddle.amp.auto_cast(level='O2'):
+                    loss = model(batch)
+                loss.backward()
+                opt.step()
+                opt.clear_grad()
+            return loss.numpy()
+
+        dist.init_parallel_env()
+        loss_disable = run_sharding_test(enable_tensor_fusion=False)
+        loss_enable = run_sharding_test(enable_tensor_fusion=True)
+        self.check_tensor_eq(loss_disable, loss_enable)
+
     def run_test_case(self):
         if self._backend == "cpu":
             paddle.set_device("cpu")
@@ -188,6 +218,7 @@ def run_test_case(self):
         self.test_sharding_stage_1_to_static()
         self.test_pure_sharding_multi_mesh_stage_1()
         self.test_sharding_stage_1_overlap_to_static()
+        self.test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion()
 
 
 if __name__ == '__main__':
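
Usage note (not part of the diff): the feature is driven entirely by the FLAGS_enable_tensor_fusion environment variable, which must be set to '1' before shard_optimizer() is constructed (it captures the flag in __init__) and before the first auto_cast/backward pass (the AMP guard checks it when registering its final backward hook). The minimal driver below is a sketch derived from the new test above; the two-card "dp" ProcessMesh and the launch command are illustrative assumptions, not part of this PR.

import os

# Set the flag first: shard_optimizer() and the AMP backward-hook setup both
# read it at construction/registration time, not on every step.
os.environ['FLAGS_enable_tensor_fusion'] = '1'

import paddle
import paddle.distributed as dist

# Assumed 1-D mesh with a "dp" dimension; run with e.g.:
#   python -m paddle.distributed.launch --devices 0,1 this_script.py
dist.init_parallel_env()
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
paddle.distributed.auto_parallel.set_mesh(mesh)

model = paddle.nn.Linear(10, 10)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
opt = dist.shard_optimizer(opt, dist.ShardingStage1(sharding_mesh_dim="dp"))
model, opt = paddle.amp.decorate(
    model, optimizers=opt, level='O2', master_grad=True
)

batch = dist.shard_tensor(paddle.rand(shape=[10, 10]), mesh, [dist.Shard(0)])
for _ in range(5):
    with paddle.amp.auto_cast(level='O2'):
        loss = model(batch)
    loss.backward()   # param hooks accumulate fp32 grads into param.main_grad
    opt.step()        # step() reads param.main_grad when the flag is '1'
    opt.clear_grad()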