
Commit d96a81b

Update auto_cast.py
1 parent 9c4289f commit d96a81b

File tree

1 file changed (+48, -1 lines)


python/paddle/amp/auto_cast.py

Lines changed: 48 additions & 1 deletion
@@ -655,6 +655,33 @@ def amp_guard(
         and not amp_global_state().already_register_final_backward_hook
     ):
 
+        def _dtensor_from_local(
+            local_tensor, mesh, placements, local_tensor_shape=None
+        ):
+            # Assume each rank has the same tensor shape for now; just use the local shape to calculate the global shape.
+            global_dims = list(local_tensor.shape)
+            if local_tensor_shape is not None:
+                global_dims = local_tensor_shape
+            for idx, placement in enumerate(placements):
+                if placement.is_shard():
+                    shard_dim = placement.get_dim()
+                    local_dim_size = global_dims[shard_dim]
+                    global_dims[shard_dim] = (
+                        local_dim_size * mesh.shape[idx]
+                    )
+
+            if paddle.in_dynamic_mode():
+                place = paddle.framework._current_expected_place()
+                place = paddle.framework._get_paddle_place(place)
+
+                return paddle.Tensor(
+                    local_tensor,
+                    dims=global_dims,
+                    process_mesh=mesh,
+                    placements=placements,
+                    place=place,
+                )
+
         def master_grad_hook():
             # NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
             # classify the params of model into different classes according to their process_mesh.
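The shape arithmetic in _dtensor_from_local is simple to state: for every Shard placement, the global extent along the sharded dimension is the local extent multiplied by the size of the corresponding mesh axis, while replicated axes leave the shape unchanged. A minimal sketch of that computation, using hypothetical Shard/Replicate stand-ins rather than Paddle's real placement classes:

# Hypothetical stand-ins for distributed placements, only to illustrate
# the shape math; they are not Paddle's Shard/Replicate classes.
class Replicate:
    def is_shard(self):
        return False


class Shard:
    def __init__(self, dim):
        self.dim = dim

    def is_shard(self):
        return True

    def get_dim(self):
        return self.dim


def global_dims_from_local(local_shape, mesh_shape, placements):
    # Assumes every rank holds an equally sized local shard, as the
    # helper in the diff above does.
    global_dims = list(local_shape)
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            global_dims[placement.get_dim()] *= mesh_shape[idx]
    return global_dims


# A [4, 8] local shard on a 2 x 4 mesh, sharded along dim 0 on the first
# mesh axis and replicated on the second, has global shape [8, 8].
print(global_dims_from_local([4, 8], [2, 4], [Shard(0), Replicate()]))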
@@ -677,7 +704,27 @@ def master_grad_hook():
 
             if len(amp_global_state().mesh2params):
                 for _, params in amp_global_state().mesh2params.items():
-                    core.eager.set_master_grads(params)
+                    for param in params:
+                        tmp_grad = param._grad_ivar()
+                        if param.main_grad is None:
+                            tmp = core.eager.Tensor(
+                                value=tmp_grad._local_value()
+                                .cast(paddle.float32)
+                                .value(),
+                                place=tmp_grad.place,
+                                name="main_grad@" + param.name,
+                            )
+                            param.main_grad = _dtensor_from_local(
+                                tmp,
+                                tmp_grad.process_mesh,
+                                tmp_grad.placements,
+                            )
+                        else:
+                            param.main_grad._local_value().add_(
+                                tmp_grad._local_value()
+                            )
+
+                    # core.eager.set_master_grads(params)
             else:
                 core.eager.set_master_grads(
                     amp_global_state().model_parameters
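The per-parameter branch above creates main_grad as a float32 copy of the first gradient and then accumulates later gradients into it with add_. A toy numpy illustration (not Paddle code) of why the accumulator is kept in float32: summing many small float16 gradients directly in float16 stalls once the running sum's spacing exceeds the increment, while float32 accumulation stays accurate.

import numpy as np

# 10,000 identical "gradients" of 1e-4; the true sum is about 1.0.
grads = [np.float16(1e-4)] * 10_000

fp16_sum = np.float16(0.0)
for g in grads:
    fp16_sum = fp16_sum + g          # stays in float16; stalls near 0.25

fp32_sum = np.float32(grads[0])      # cast once, like creating main_grad
for g in grads[1:]:
    fp32_sum += np.float32(g)        # accumulate in float32

print(fp16_sum)  # far below 1.0: increments smaller than the fp16 spacing are lost
print(fp32_sum)  # close to 1.0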
