
Commit 63a9a1c

[AutoParallel] Reconstruct sharding mesh dimension inference logic - Part1 rename func parameter (#69393)
* rename
* update
* update test
* rename
* Update semi_auto_parallel_sharding_stage_1.py
* fix
1 parent 407b1b8 commit 63a9a1c

3 files changed: +42 -46 lines changed
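
The user-visible effect of the rename: the sharding stage helpers now take a single `sharding_mesh_dim` argument instead of the old `shard_dims` / `shard_axis` pair, and the sharding axis is inferred by the optimizer wrapper. A minimal before/after sketch of the call site, assuming a 1-D process mesh whose dimension is named "dp" and a small Linear model (the mesh and model here are illustrative, not part of the patch):

import paddle
import paddle.distributed as dist

# Illustrative setup: two ranks on a 1-D mesh whose only dimension is "dp".
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
linear = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(parameters=linear.parameters())

# Before this commit:
#     opt = dist.shard_optimizer(opt, dist.ShardingStage1(shard_dims="dp"))
# After this commit, the mesh dimension is passed as `sharding_mesh_dim`:
opt = dist.shard_optimizer(opt, dist.ShardingStage1(sharding_mesh_dim="dp"))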

python/paddle/distributed/auto_parallel/api.py (+34 -40)
@@ -992,7 +992,7 @@ def replicate_layer_params_and_buffers(
     )
 
 
-def get_placement_with_sharding(param, sharding_mesh_axis):
+def get_placement_with_sharding(param, sharding_axis):
     shard_axis = -1
     for placement in param.placements:
         if isinstance(placement, dist.Shard):
@@ -1011,7 +1011,7 @@ def get_placement_with_sharding(param, sharding_mesh_axis):
 
     new_placements = param.placements
     if placement_with_sharding is not None:
-        new_placements[sharding_mesh_axis] = placement_with_sharding
+        new_placements[sharding_axis] = placement_with_sharding
 
     return new_placements
 
@@ -1040,15 +1040,15 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         ):
             self._shard_clip = True
         self._shard_fn = shard_fn
-        self._sharding_mesh_axis = None
+        self._sharding_axis = None
         self._sharding_degree = None
         self.gradient_accumulation_steps = gradient_accumulation_steps
 
         if isinstance(
             self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
         ):
             self._set_and_check_sharding_prop_from_param()
-            self._shard_fn._set_sharding_mesh_axis(self._sharding_mesh_axis)
+            self._shard_fn._set_sharding_axis(self._sharding_axis)
 
         # Invoke register hook for sharding stage 2 strategy
         if isinstance(self._shard_fn, ShardingStage2):
@@ -1066,12 +1066,12 @@ def _set_and_check_sharding_prop_from_param(self):
             len(self._shard_fn._mesh._shape) == 1
         ):
             self._sharding_degree = self._shard_fn._mesh.get_dim_size(0)
-            self._sharding_mesh_axis = 0
+            self._sharding_axis = 0
         elif (self._shard_fn._mesh is not None) and (
             'dp' in self._shard_fn._mesh.dim_names
         ):
             self._sharding_degree = self._shard_fn._mesh.get_dim_size('dp')
-            self._sharding_mesh_axis = 0
+            self._sharding_axis = 0
         else:
             param_list = self._inner_opt._parameter_list
             for param in param_list:
@@ -1090,7 +1090,7 @@ def _set_and_check_sharding_prop_from_param(self):
                    for idx, placement in enumerate(placements):
                        if isinstance(placement, dist.Replicate):
                            self._sharding_degree = mesh.dim_size(idx)
-                           self._sharding_mesh_axis = idx
+                           self._sharding_axis = idx
                            break
                elif any(
                    isinstance(placement, dist.Partial)
@@ -1100,12 +1100,12 @@ def _set_and_check_sharding_prop_from_param(self):
                else:
                    # check the placement on sharding axis is Replicate
                    assert isinstance(
-                       placements[self._sharding_mesh_axis], dist.Replicate
-                   ), "The placement on sharding_mesh_axis should be Replicate"
+                       placements[self._sharding_axis], dist.Replicate
+                   ), "The placement on sharding_axis should be Replicate"
 
                    # check the sharding degree since it has already been set
                    assert (
-                       mesh.dim_size(self._sharding_mesh_axis)
+                       mesh.dim_size(self._sharding_axis)
                        == self._sharding_degree
                    ), "The sharding degree of all parameters must be equal currently."
 
@@ -1116,9 +1116,9 @@ def _set_and_check_sharding_prop_from_param(self):
         if self._sharding_degree is None and all_params_replicated_on_each_mesh:
             global_mesh = fleet.auto.get_mesh()
             self._sharding_degree = global_mesh.get_dim_size(
-                self._shard_fn._shard_dims
+                self._shard_fn._sharding_mesh_dim
             )
-            self._sharding_mesh_axis = 0
+            self._sharding_axis = 0
 
         assert (
             self._sharding_degree is not None
@@ -1176,7 +1176,7 @@ def _reset_placements(self, param):
         # in pir mode, reshard pass will automatically handle inplace case, so no extra work is required here.
         if not isinstance(param, pir.Value):
             new_placement = param.placements
-            new_placement[self._sharding_mesh_axis] = dist.Replicate()
+            new_placement[self._sharding_axis] = dist.Replicate()
             out_param = dist.reshard(
                 param, param.process_mesh, new_placement
             )
@@ -1322,13 +1322,13 @@ def __setattr__(self, item, value):
 
 
 class _ShardingStageBase:
-    def __init__(self, mesh, shard_dims, shard_axis):
+    def __init__(self, mesh, sharding_mesh_dim):
         self._mesh = mesh
-        self._sharding_mesh_axis = shard_axis
-        self._shard_dims = shard_dims
+        self._sharding_axis = 0
+        self._sharding_mesh_dim = sharding_mesh_dim
 
-    def _set_sharding_mesh_axis(self, sharding_mesh_axis):
-        self._sharding_mesh_axis = sharding_mesh_axis
+    def _set_sharding_axis(self, sharding_axis):
+        self._sharding_axis = sharding_axis
 
     def shard_master_weight(
         self, param: Tensor, master_weight: Tensor
@@ -1340,7 +1340,7 @@ def shard_master_weight(
                 data_op.name() == "pd_op.data"
             ), "The master weight must be a result of data op."
             placements = get_placement_with_sharding(
-                param, self._sharding_mesh_axis
+                param, self._sharding_axis
             )
             dim_map, partial_status = to_dim_map(
                 placements, len(master_weight.shape)
@@ -1368,8 +1368,7 @@ class ShardingStage1(_ShardingStageBase):
 
     Args:
         mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the used processes for dense type parameters. Note: Currently, only one mesh configuration is supported for all dense parameters. If there is a need for multiple mesh configurations, please configure them yourself in the upper layer networking code.
-        shard_dims(None|int|str): The sharding dimension in the mesh.
-        shard_axis(int): The sharding axis of the weight tensor.
+        sharding_mesh_dim(None|int|str): The sharding dimension in the mesh.
 
     Examples:
         .. code-block:: python
@@ -1405,17 +1404,16 @@ class ShardingStage1(_ShardingStageBase):
     def __init__(
         self,
         mesh: ProcessMesh | None = None,
-        shard_dims: int | str | None = None,
-        shard_axis: int = 0,
+        sharding_mesh_dim: int | str | None = None,
     ) -> None:
-        super().__init__(mesh, shard_dims, shard_axis)
+        super().__init__(mesh, sharding_mesh_dim)
 
     def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor:
         if param.is_dist():
             # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
             if 'beta' not in key:
                 placements = get_placement_with_sharding(
-                    param, self._sharding_mesh_axis
+                    param, self._sharding_axis
                 )
             else:
                 placements = [
@@ -1462,8 +1460,7 @@ class ShardingStage2(_ShardingStageBase):
 
     Args:
         mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the used processes for dense type parameters. Note: Currently, only one mesh configuration is supported for all dense parameters. If there is a need for multiple mesh configurations, please configure them yourself in the upper layer networking code.
-        shard_dims(None|int|str): The sharding dimension name in the mesh.
-        shard_axis(int): The sharding axis of the weight tensor.
+        sharding_mesh_dim(None|int|str): The sharding dimension name in the mesh.
 
     Examples:
         .. code-block:: python
@@ -1499,17 +1496,16 @@ class ShardingStage2(_ShardingStageBase):
     def __init__(
         self,
         mesh: ProcessMesh | None = None,
-        shard_dims: int | str | None = None,
-        shard_axis: int = 0,
+        sharding_mesh_dim: int | str | None = None,
    ) -> None:
-        super().__init__(mesh, shard_dims, shard_axis)
+        super().__init__(mesh, sharding_mesh_dim)
 
     def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor:
         if param.is_dist():
             # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
             if 'beta' not in key:
                 placements = get_placement_with_sharding(
-                    param, self._sharding_mesh_axis
+                    param, self._sharding_axis
                 )
             else:
                 placements = [
@@ -1580,8 +1576,7 @@ class ShardingStage3(_ShardingStageBase):
 
     Args:
         mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the used processes for dense type parameters. Note: Currently, only one mesh configuration is supported for all dense parameters. If there is a need for multiple mesh configurations, please configure them yourself in the upper layer networking code.
-        shard_dims(None|int|str): The sharding dimension name in the mesh.
-        shard_axis(int): The sharding axis of the weight tensor.
+        sharding_mesh_dim(None|int|str): The sharding dimension name in the mesh.
 
     Examples:
         .. code-block:: python
@@ -1617,10 +1612,9 @@
     def __init__(
         self,
        mesh: ProcessMesh | None = None,
-        shard_dims: int | str | None = None,
-        shard_axis: int = 0,
+        sharding_mesh_dim: int | str | None = None,
    ) -> None:
-        super().__init__(mesh, shard_dims, shard_axis)
+        super().__init__(mesh, sharding_mesh_dim)
 
     def _shard_parameter(self, param):
         if param.is_dense() and self._mesh is not None:
@@ -1630,7 +1624,7 @@ def _shard_parameter(self, param):
             param._to_dist_(placements, self._mesh)
         if param.is_dist():
             new_placements = get_placement_with_sharding(
-                param, self._sharding_mesh_axis
+                param, self._sharding_axis
             )
             shard_param = dist.reshard(
                 param, param.process_mesh, new_placements
@@ -1641,8 +1635,8 @@ def _shard_parameter(self, param):
     def _unshard_parameter(self, param):
         if param.is_dist():
             new_placements = param.placements
-            if isinstance(new_placements[self._sharding_mesh_axis], dist.Shard):
-                new_placements[self._sharding_mesh_axis] = dist.Replicate()
+            if isinstance(new_placements[self._sharding_axis], dist.Shard):
+                new_placements[self._sharding_axis] = dist.Replicate()
 
             new_param = dist.reshard(param, param.process_mesh, new_placements)
             param.get_tensor()._share_data_with(new_param.get_tensor())
@@ -1657,7 +1651,7 @@ def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor:
                 for placement in placements
             ):
                 placements = get_placement_with_sharding(
-                    param, self._sharding_mesh_axis
+                    param, self._sharding_axis
                 )
 
         else:
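
The docstring changes above for ShardingStage1, ShardingStage2, and ShardingStage3 all describe the same renamed keyword, and the three constructors now share an identical two-argument signature. A short sketch of how each stage is constructed after this change (the 2-D mesh below is illustrative, not part of the patch; as before, `mesh` may also be left as None):

import paddle.distributed as dist

# Illustrative 2-D mesh; "dp" is the dimension the optimizer states are sharded over.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# All three stages now accept `sharding_mesh_dim` (an int index or a dim name)
# in place of the removed `shard_dims` / `shard_axis` pair.
stage1 = dist.ShardingStage1(mesh=mesh, sharding_mesh_dim="dp")
stage2 = dist.ShardingStage2(mesh=mesh, sharding_mesh_dim="dp")
stage3 = dist.ShardingStage3(mesh=mesh, sharding_mesh_dim="dp")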

python/paddle/distributed/auto_parallel/sharding.py (+5 -5)
@@ -106,8 +106,8 @@ def __init__(self, optimizer, shard_fn=None, strategy=None):
            else:
                self.pp_meshes.add(mesh)
 
-        self._sharding_mesh_axis = mesh._dim_names.index("dp")
-        self._sharding_degree = mesh._shape[self._sharding_mesh_axis]
+        self._sharding_axis = mesh._dim_names.index("dp")
+        self._sharding_degree = mesh._shape[self._sharding_axis]
         self._mp_mesh_axis = -1
         self._mp_degree = 1
         if "mp" in mesh._dim_names:
@@ -151,7 +151,7 @@ def apply_gradients(self, params_grads):
 
            if dist.get_rank() in param_dist_attr.process_mesh.process_ids:
                sub_mesh = get_1D_sub_process_mesh(
-                   param_dist_attr.process_mesh, self._sharding_mesh_axis
+                   param_dist_attr.process_mesh, self._sharding_axis
                )
                assert (
                    sorted(sub_mesh.process_ids) == self._sharding_group.ranks
@@ -178,7 +178,7 @@ def apply_gradients(self, params_grads):
                param._local_shape == grad._local_shape
            ), f"Parameter and grad should have same local shape. but received name:{param.name}, parameter:{param}, grad: {grad}."
 
-           if self._sharding_mesh_axis not in grad_dist_attr.partial_dims:
+           if self._sharding_axis not in grad_dist_attr.partial_dims:
                new_params_grads.append((param, grad))
                if param.optimize_attr is None:
                    param.optimize_attr = {'no_fusion': True}
@@ -361,7 +361,7 @@ def apply_gradients(self, params_grads):
                partail_status = (
                    group_grad_list[index].dist_attr().partial_status
                )
-               partail_status.pop(self._sharding_mesh_axis)
+               partail_status.pop(self._sharding_axis)
                slice_grad_dist_attr = pir.create_tensor_dist_attribute(
                    slice_grad.process_mesh, [-1], partail_status
                )

test/auto_parallel/semi_auto_parallel_sharding_stage_1.py (+3 -1)
@@ -74,7 +74,9 @@ def test_pure_sharding_multi_mesh_stage_1(self):
         batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
         # shard optimizer with stage 1 fn
         opt = paddle.optimizer.AdamW(parameters=linear.parameters())
-        opt = dist.shard_optimizer(opt, dist.ShardingStage1(shard_dims="dp"))
+        opt = dist.shard_optimizer(
+            opt, dist.ShardingStage1(sharding_mesh_dim="dp")
+        )
         for _ in range(5):
             loss = linear(batch)
             loss.backward()
