
Commit fce2670

[Auto Parallel] Add tensor_fusion and overlap in auto dy sharding (#72551)
Parent: 0ef8ed1

6 files changed, +544 -45 lines changed (two of the changed files are shown below)

paddle/fluid/pybind/tensor.cc (+2 lines)
@@ -1199,6 +1199,8 @@ void BindTensor(pybind11::module &m) {  // NOLINT
             self.unsafe_mutable_value()->ShareDataNoCheckWith(src.value());
             return self;
           })
+      .def("_numel",
+           [](DistTensor &self) -> int64_t { return self.value().numel(); })
       .def("_share_data_with",
            [](DistTensor &self, const DistTensor &src) {
              self.unsafe_set_dims(src.dims());
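The new _numel binding returns the element count of the DistTensor's local DenseTensor (self.value()), i.e. the size of the shard held by the current rank, not the global size. A rough sketch of that distinction, assuming a two-card job launched with python -m paddle.distributed.launch (the mesh, shapes, and sharding below are illustrative, not taken from this commit):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x = paddle.ones([8, 4])
# Shard dim 0 across the two ranks: each rank holds a [4, 4] local shard.
xd = dist.shard_tensor(x, mesh, [dist.Shard(0)])

print(xd.shape)                   # [8, 4]: global dims
print(xd._local_value().numel())  # 16: local element count, which is what _numel reports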

python/paddle/amp/auto_cast.py (+62 -8 lines)
@@ -14,6 +14,7 @@
 from __future__ import annotations
 
 import copy
+import os
 import warnings
 from typing import (
     TYPE_CHECKING,
@@ -655,6 +656,24 @@ def amp_guard(
         and not amp_global_state().already_register_final_backward_hook
     ):
 
+        def _dtensor_from_local(local_tensor, mesh, placements):
+            global_dims = list(local_tensor.shape)
+            for idx, placement in enumerate(placements):
+                if placement.is_shard():
+                    global_dims[placement.get_dim()] = (
+                        global_dims[placement.get_dim()] * mesh.shape[idx]
+                    )
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+
+            return paddle.Tensor(
+                local_tensor,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+
         def master_grad_hook():
             # NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
             # classify the params of model into different classes according to their process_mesh.
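_dtensor_from_local rebuilds the global shape of a parameter from its local shard: for every mesh axis whose placement is a shard, the tensor dimension being sharded is multiplied by the size of that mesh axis, and the result is handed to the paddle.Tensor constructor together with the mesh and placements. A standalone sketch of just that shape arithmetic, with hypothetical stand-ins (FakeShard, FakeReplicate, plain lists for shapes) in place of the real placement and mesh objects:

# Hypothetical stand-ins mimicking the is_shard()/get_dim() interface used above.
class FakeShard:
    def __init__(self, dim):
        self.dim = dim

    def is_shard(self):
        return True

    def get_dim(self):
        return self.dim


class FakeReplicate:
    def is_shard(self):
        return False


def global_dims_from_local(local_shape, mesh_shape, placements):
    dims = list(local_shape)
    for mesh_axis, placement in enumerate(placements):
        if placement.is_shard():
            # A tensor dim sharded over mesh axis mesh_axis is mesh_shape[mesh_axis]
            # times larger globally than it is on any single rank.
            dims[placement.get_dim()] *= mesh_shape[mesh_axis]
    return dims


# Local shard [4, 16] on a 2x4 mesh, sharded along tensor dim 0 over mesh axis 0:
print(global_dims_from_local([4, 16], [2, 4], [FakeShard(0), FakeReplicate()]))  # [8, 16]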
@@ -674,17 +693,52 @@ def master_grad_hook():
                     param.process_mesh
                 ].append(param)
                 amp_global_state().already_classify_params_meshes = True
-
-            if len(amp_global_state().mesh2params):
-                for _, params in amp_global_state().mesh2params.items():
-                    core.eager.set_master_grads(params)
-            else:
-                core.eager.set_master_grads(
-                    amp_global_state().model_parameters
-                )
+            if os.getenv("FLAGS_enable_tensor_fusion") not in [
+                "True",
+                "true",
+                "1",
+            ]:
+                if len(amp_global_state().mesh2params):
+                    for _, params in amp_global_state().mesh2params.items():
+                        core.eager.set_master_grads(params)
+                else:
+                    core.eager.set_master_grads(
+                        amp_global_state().model_parameters
+                    )
 
             amp_global_state().already_register_final_backward_hook = False
 
+        def _update_main_grad_hook(param):
+            @paddle.autograd.no_grad()
+            def param_hook(tmp_grad):
+                if tmp_grad is not None and tmp_grad._is_initialized():
+                    if param.main_grad is None:
+                        tmp = core.eager.Tensor(
+                            value=tmp_grad._local_value()
+                            .cast(paddle.float32)
+                            .value(),
+                            place=tmp_grad.place,
+                            name="main_grad@" + param.name,
+                        )
+                        param.main_grad = _dtensor_from_local(
+                            tmp,
+                            tmp_grad.process_mesh,
+                            tmp_grad.placements,
+                        )
+                    else:
+                        param.main_grad._local_value().add_(
+                            tmp_grad._local_value()
+                        )
+                    tmp_grad._clear_data()
+
+            return param_hook
+
+        if os.getenv("FLAGS_enable_tensor_fusion") in ["True", "true", "1"]:
+            for param in amp_global_state().model_parameters:
+                if not hasattr(param, "main_grad"):
+                    param.main_grad = None
+                param._register_grad_hook(_update_main_grad_hook(param))
+
         core.eager._add_backward_final_hook(master_grad_hook)
         amp_global_state().already_register_final_backward_hook = True
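_update_main_grad_hook follows a standard master-gradient pattern: on the first backward pass it materializes a float32 copy of the parameter's local gradient and wraps it back into a DistTensor via _dtensor_from_local as param.main_grad; on later passes it accumulates the new local gradient in place and clears the temporary low-precision grad. The path is only active when the FLAGS_enable_tensor_fusion environment variable is set to "True", "true", or "1" before the amp context registers its hooks (for example, export FLAGS_enable_tensor_fusion=1 in the launch script); otherwise the existing core.eager.set_master_grads route runs. A single-process sketch of the accumulate-into-fp32 pattern using plain dense tensors and the public register_hook API (make_main_grad_hook is a hypothetical stand-in, not the PR's hook, which works on DistTensor locals through the internal _register_grad_hook):

import paddle


def make_main_grad_hook(param):
    @paddle.no_grad()
    def hook(tmp_grad):
        if tmp_grad is not None and tmp_grad._is_initialized():
            if param.main_grad is None:
                # First backward: keep a float32 master copy of the gradient.
                # (Under amp the incoming grad would be float16/bfloat16.)
                param.main_grad = tmp_grad.cast("float32")
            else:
                # Later steps: accumulate in float32; the PR additionally
                # frees the temporary grad with tmp_grad._clear_data().
                param.main_grad.add_(tmp_grad.cast("float32"))

    return hook


linear = paddle.nn.Linear(4, 4)
for p in linear.parameters():
    p.main_grad = None
    p.register_hook(make_main_grad_hook(p))

for _ in range(2):  # two micro-steps accumulate into main_grad
    linear(paddle.randn([2, 4])).mean().backward()

print(linear.weight.main_grad.dtype)  # paddle.float32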
