1 file changed: +3 −9 lines changed

@@ -656,18 +656,12 @@ def amp_guard(
             and not amp_global_state().already_register_final_backward_hook
         ):
 
-            def _dtensor_from_local(
-                local_tensor, mesh, placements, local_tensor_shape=None
-            ):
+            def _dtensor_from_local(local_tensor, mesh, placements):
                 global_dims = list(local_tensor.shape)
-                if local_tensor_shape is not None:
-                    global_dims = local_tensor_shape
                 for idx, placement in enumerate(placements):
                     if placement.is_shard():
-                        shard_dim = placement.get_dim()
-                        local_dim_size = global_dims[shard_dim]
-                        global_dims[shard_dim] = (
-                            local_dim_size * mesh.shape[idx]
+                        global_dims[placement.get_dim()] = (
+                            global_dims[placement.get_dim()] * mesh.shape[idx]
                         )
                 place = paddle.framework._current_expected_place()
                 place = paddle.framework._get_paddle_place(place)
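For context, the simplified _dtensor_from_local derives the global shape purely from the local tensor: every dimension that a placement shards is multiplied by the size of the corresponding mesh axis, and the caller-supplied local_tensor_shape override is dropped. Below is a minimal standalone sketch of that shape computation; the Shard/Replicate classes and the global_dims_from_local helper are hypothetical stand-ins that mirror the is_shard()/get_dim() interface seen in the diff, not Paddle's actual placement types.

# Minimal sketch of the global-shape computation in the simplified
# _dtensor_from_local. Shard and Replicate here are hypothetical
# stand-ins mirroring the is_shard()/get_dim() interface from the
# diff; they are not Paddle's actual placement classes.


class Replicate:
    def is_shard(self):
        return False


class Shard:
    def __init__(self, dim):
        self._dim = dim

    def is_shard(self):
        return True

    def get_dim(self):
        return self._dim


def global_dims_from_local(local_shape, mesh_shape, placements):
    # Start from the local shard's shape; for each mesh axis that
    # shards a tensor dimension, scale that dimension by the size of
    # that mesh axis.
    global_dims = list(local_shape)
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            global_dims[placement.get_dim()] *= mesh_shape[idx]
    return global_dims


# Example: a [4, 8] local shard on a 2x4 mesh, sharded on tensor dim 0
# along mesh axis 0 and replicated along mesh axis 1 -> global [8, 8].
print(global_dims_from_local([4, 8], (2, 4), [Shard(0), Replicate()]))

With the override parameter gone, the starting shape is always the actual local tensor's shape rather than a value supplied by the caller, so the computed global shape assumes every shard along a mesh axis has the same size.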