Skip to content

Commit 388364d

Browse files
authored
Update auto_cast.py
1 parent 0459942 commit 388364d

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

python/paddle/amp/auto_cast.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -656,18 +656,12 @@ def amp_guard(
656656
and not amp_global_state().already_register_final_backward_hook
657657
):
658658

659-
def _dtensor_from_local(
660-
local_tensor, mesh, placements, local_tensor_shape=None
661-
):
659+
def _dtensor_from_local(local_tensor, mesh, placements):
662660
global_dims = list(local_tensor.shape)
663-
if local_tensor_shape is not None:
664-
global_dims = local_tensor_shape
665661
for idx, placement in enumerate(placements):
666662
if placement.is_shard():
667-
shard_dim = placement.get_dim()
668-
local_dim_size = global_dims[shard_dim]
669-
global_dims[shard_dim] = (
670-
local_dim_size * mesh.shape[idx]
663+
global_dims[placement.get_dim()] = (
664+
global_dims[placement.get_dim()] * mesh.shape[idx]
671665
)
672666
place = paddle.framework._current_expected_place()
673667
place = paddle.framework._get_paddle_place(place)

0 commit comments

Comments
 (0)