[Auto Parallel] add main_grad for sharding in auto dy #72493

Open · wants to merge 11 commits into develop
66 changes: 58 additions & 8 deletions python/paddle/amp/auto_cast.py
@@ -14,6 +14,7 @@
from __future__ import annotations

import copy
import os
import warnings
from typing import (
TYPE_CHECKING,
@@ -655,6 +656,24 @@ def amp_guard(
and not amp_global_state().already_register_final_backward_hook
):

def _dtensor_from_local(local_tensor, mesh, placements):
global_dims = list(local_tensor.shape)
for idx, placement in enumerate(placements):
if placement.is_shard():
global_dims[placement.get_dim()] = (
global_dims[placement.get_dim()] * mesh.shape[idx]
)
place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)

return paddle.Tensor(
local_tensor,
dims=global_dims,
process_mesh=mesh,
placements=placements,
place=place,
)

def master_grad_hook():
# NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
# classify the params of model into different classes according to their process_mesh.
@@ -674,17 +693,48 @@ def master_grad_hook():
param.process_mesh
].append(param)
amp_global_state().already_classify_params_meshes = True

if len(amp_global_state().mesh2params):
for _, params in amp_global_state().mesh2params.items():
core.eager.set_master_grads(params)
else:
core.eager.set_master_grads(
amp_global_state().model_parameters
)
if not os.getenv("FLAGS_enable_tensor_fusion") == '1':
if len(amp_global_state().mesh2params):
for _, params in amp_global_state().mesh2params.items():
core.eager.set_master_grads(params)
else:
core.eager.set_master_grads(
amp_global_state().model_parameters
)

amp_global_state().already_register_final_backward_hook = False

def _update_main_grad_hook(param):
@paddle.autograd.no_grad()
def param_hook(tmp_grad):
if tmp_grad is not None and tmp_grad._is_initialized():
if param.main_grad is None:
tmp = core.eager.Tensor(
value=tmp_grad._local_value()
.cast(paddle.float32)
.value(),
place=tmp_grad.place,
name="main_grad@" + param.name,
)
param.main_grad = _dtensor_from_local(
tmp,
tmp_grad.process_mesh,
tmp_grad.placements,
)
else:
param.main_grad._local_value().add_(
tmp_grad._local_value()
)
tmp_grad._clear_data()

return param_hook

if os.getenv("FLAGS_enable_tensor_fusion") == '1':
for param in amp_global_state().model_parameters:
if not hasattr(param, "main_grad"):
param.main_grad = None
param._register_grad_hook(_update_main_grad_hook(param))

core.eager._add_backward_final_hook(master_grad_hook)
amp_global_state().already_register_final_backward_hook = True

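For context, the `_dtensor_from_local` helper added above in auto_cast.py rebuilds the global shape by multiplying every sharded dimension of the local tensor by the size of the corresponding mesh axis, then wraps the local value into a distributed tensor on the current place. A minimal, framework-free sketch of that shape rule, assuming one placement per mesh axis (the `*Stub` classes are hypothetical stand-ins for `dist.Shard`/`dist.Replicate`, for illustration only):

```python
from dataclasses import dataclass


@dataclass
class ShardStub:
    """Hypothetical stand-in for dist.Shard(dim): shards one tensor dim."""
    dim: int

    def is_shard(self):
        return True

    def get_dim(self):
        return self.dim


class ReplicateStub:
    """Hypothetical stand-in for dist.Replicate(): no sharding on this axis."""

    def is_shard(self):
        return False


def infer_global_dims(local_shape, mesh_shape, placements):
    # Multiply each sharded tensor dim by the size of its mesh axis,
    # exactly as the loop in _dtensor_from_local does.
    global_dims = list(local_shape)
    for mesh_axis, placement in enumerate(placements):
        if placement.is_shard():
            global_dims[placement.get_dim()] *= mesh_shape[mesh_axis]
    return global_dims


# A [4, 10] local shard on a 2x2 mesh, sharded on tensor dim 0 along mesh
# axis 0 and replicated along mesh axis 1, reconstructs an [8, 10] global shape.
print(infer_global_dims([4, 10], [2, 2], [ShardStub(0), ReplicateStub()]))  # [8, 10]
```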
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -1128,6 +1128,10 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
for param in self._inner_opt._parameter_list:
self._shard_fn._shard_parameter(param)

self.enable_tensor_fusion = (
os.getenv("FLAGS_enable_tensor_fusion") == '1'
)

def _set_and_check_sharding_prop_from_param(self):
global_mesh = fleet.auto.get_mesh()
if global_mesh:
@@ -1253,6 +1257,9 @@ def _create_accumulators(self, block, parameters):

def _finish_update(self, block, parameters_and_grads):
self._inner_opt._finish_update(block, parameters_and_grads)
if self.enable_tensor_fusion:
for param, _ in parameters_and_grads:
param.main_grad._local_value().zero_()
if isinstance(parameters_and_grads, list):
for p, _ in parameters_and_grads:
self._reset_placements(p)
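The `_finish_update` change above zeroes the persistent fp32 `main_grad` buffer once the optimizer has consumed it; `opt.clear_grad()` only touches the low-precision `.grad`, which the registered hook already cast, accumulated, and released. A framework-free sketch of that accumulate-then-zero lifecycle, assuming this is the intended behavior (`ParamStub` and the float lists are illustrative stand-ins, not Paddle API):

```python
class ParamStub:
    """Illustrative parameter holding a persistent fp32 main_grad accumulator."""

    def __init__(self):
        self.main_grad = None

    def accumulate(self, low_precision_grad):
        # Mirrors _update_main_grad_hook: cast to fp32 once, then add in place.
        fp32_grad = [float(g) for g in low_precision_grad]
        if self.main_grad is None:
            self.main_grad = fp32_grad
        else:
            self.main_grad = [a + b for a, b in zip(self.main_grad, fp32_grad)]

    def finish_update(self):
        # Mirrors _finish_update: zero (not free) the buffer after the step.
        self.main_grad = [0.0] * len(self.main_grad)


p = ParamStub()
for _ in range(2):                      # two accumulation micro-steps
    p.accumulate([1.0, 2.0, 3.0])
print(p.main_grad)                      # [2.0, 4.0, 6.0] is what the optimizer sees
p.finish_update()
print(p.main_grad)                      # [0.0, 0.0, 0.0] ready for the next step
```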
13 changes: 10 additions & 3 deletions python/paddle/optimizer/optimizer.py
@@ -2000,9 +2000,16 @@ def step(self) -> None:
for param in self._param_groups:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))
if os.getenv("FLAGS_enable_tensor_fusion") == '1':
Contributor:

1. The flag check should also handle values such as `true`/`True`.
2. Rather than checking `FLAGS_enable_tensor_fusion` directly in auto_cast.py, api.py, and optimizer.py, consider turning it into a configuration option.

Contributor Author:

OK, these will be addressed together in a follow-up PR.

if (
hasattr(param, "main_grad")
and param.main_grad is not None
):
params_grads.append((param, param.main_grad))
else:
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))

self._apply_optimize(
loss=None,
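As the reviewer notes above, `os.getenv("FLAGS_enable_tensor_fusion") == '1'` misses spellings such as `true`/`True`, and the same check is now duplicated across auto_cast.py, api.py, and optimizer.py. A hypothetical centralized helper along these lines could cover both points; this is a suggestion sketch, not code from the PR:

```python
import os

_TRUTHY = {"1", "true", "t", "yes", "y", "on"}


def tensor_fusion_enabled() -> bool:
    """Read FLAGS_enable_tensor_fusion once, accepting 1/true/True/on/..."""
    return os.getenv("FLAGS_enable_tensor_fusion", "0").strip().lower() in _TRUTHY
```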
31 changes: 31 additions & 0 deletions test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
@@ -173,6 +173,36 @@ def test_sharding_stage_1_overlap_to_static(self):
for batch_id, (image, label) in enumerate(dist_loader()):
loss = dist_model(image, label)

def test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion(self):
def run_sharding_test(enable_tensor_fusion):
os.environ['FLAGS_enable_tensor_fusion'] = (
'1' if enable_tensor_fusion else '0'
)
paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
paddle.seed(self._seed)
model = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10])
batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
opt = paddle.optimizer.AdamW(parameters=model.parameters())
opt = dist.shard_optimizer(
opt, dist.ShardingStage1(sharding_mesh_dim="dp")
)
model, opt = paddle.amp.decorate(
model, optimizers=opt, level='O2', master_grad=True
)
for _ in range(5):
with paddle.amp.auto_cast(level='O2'):
loss = model(batch)
loss.backward()
opt.step()
opt.clear_grad()
return loss.numpy()

dist.init_parallel_env()
loss_disable = run_sharding_test(enable_tensor_fusion=False)
loss_enable = run_sharding_test(enable_tensor_fusion=True)
self.check_tensor_eq(loss_disable, loss_enable)

def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
@@ -188,6 +218,7 @@ def run_test_case(self):
self.test_sharding_stage_1_to_static()
self.test_pure_sharding_multi_mesh_stage_1()
self.test_sharding_stage_1_overlap_to_static()
self.test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion()


if __name__ == '__main__':