), "The master weight must be a result of data op."
1342
1342
placements=get_placement_with_sharding(
1343
-
param, self._sharding_mesh_axis
1343
+
param, self._sharding_axis
1344
1344
)
1345
1345
dim_map, partial_status=to_dim_map(
1346
1346
placements, len(master_weight.shape)
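For context on the hunk above: to_dim_map turns the master weight's placement list into a per-tensor-dim mapping onto mesh dimensions plus a record of partially reduced mesh dims. Below is a minimal sketch of that conversion using a simplified placement encoding; the helper name and representation are illustrative only, not Paddle's internal API.

# Hypothetical sketch (not Paddle's internal to_dim_map) of how a placement
# list over the mesh could translate into a dim_map for the master weight.
# Placements are encoded here simply as ("shard", tensor_dim),
# ("replicate",) or ("partial", reduce_type).
def placements_to_dim_map(placements, tensor_ndim):
    dim_map = [-1] * tensor_ndim          # -1 means "not split along any mesh dim"
    partial_status = {}                   # mesh dim -> pending reduce type
    for mesh_dim, p in enumerate(placements):
        if p[0] == "shard":
            dim_map[p[1]] = mesh_dim      # tensor dim p[1] is split along mesh_dim
        elif p[0] == "partial":
            partial_status[mesh_dim] = p[1]
    return dim_map, partial_status

# e.g. a weight split on its first tensor axis along mesh dim 1:
# placements_to_dim_map([("replicate",), ("shard", 0)], 2) -> ([1, -1], {})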
@@ -1368,8 +1368,7 @@ class ShardingStage1(_ShardingStageBase):
 
     Args:
         mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the used processes for dense type parameters. Note: Currently, only one mesh configuration is supported for all dense parameters. If there is a need for multiple mesh configurations, please configure them yourself in the upper layer networking code.
-        shard_dims(None|int|str): The sharding dimension in the mesh.
-        shard_axis(int): The sharding axis of the weight tensor.
+        sharding_mesh_dim(None|int|str): The sharding dimension in the mesh.
 
     Examples:
         .. code-block:: python
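The docstring example itself is elided in this hunk. As a minimal usage sketch for the renamed argument: it assumes ShardingStage1 accepts the mesh and the sharding_mesh_dim keyword documented above (keyword names are taken from this docstring, not verified against the constructor).

# Sketch only: keyword names follow the docstring above; other names are illustrative.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
layer = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())
# Stage 1 shards only the optimizer states along the "dp" mesh dimension.
opt = dist.shard_optimizer(opt, dist.ShardingStage1(sharding_mesh_dim="dp", mesh=mesh))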
@@ -1405,17 +1404,16 @@ class ShardingStage1(_ShardingStageBase):
             # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
             if 'beta' not in key:
                 placements = get_placement_with_sharding(
-                    param, self._sharding_mesh_axis
+                    param, self._sharding_axis
                 )
             else:
                 placements = [
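The branch above shards momentum-style accumulators with their parameter along self._sharding_axis, while the scalar beta1/beta2 power accumulators stay replicated across the parameter's mesh. A self-contained sketch of that decision using Paddle's public Shard/Replicate placements follows; the helper itself and the choice of tensor dim are illustrative, not the code in this diff.

# Hypothetical sketch of the branch above: momentum-like accumulators follow
# the parameter's sharding, scalar beta accumulators are replicated.
import paddle.distributed as dist

def accumulator_placements(key, param_placements, sharding_axis):
    if 'beta' not in key:
        # reuse the parameter's placements, sharding along the sharding mesh axis
        placements = list(param_placements)
        placements[sharding_axis] = dist.Shard(0)   # illustrative choice of tensor dim
        return placements
    # beta1/beta2 power accumulators are scalars: replicate on every mesh dim
    return [dist.Replicate() for _ in param_placements]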
@@ -1462,8 +1460,7 @@ class ShardingStage2(_ShardingStageBase):
 
     Args:
         mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the used processes for dense type parameters. Note: Currently, only one mesh configuration is supported for all dense parameters. If there is a need for multiple mesh configurations, please configure them yourself in the upper layer networking code.
-        shard_dims(None|int|str): The sharding dimension name in the mesh.
-        shard_axis(int): The sharding axis of the weight tensor.
+        sharding_mesh_dim(None|int|str): The sharding dimension name in the mesh.
 
     Examples:
         .. code-block:: python
@@ -1499,17 +1496,16 @@ class ShardingStage2(_ShardingStageBase):
             # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
             if 'beta' not in key:
                 placements = get_placement_with_sharding(
-                    param, self._sharding_mesh_axis
+                    param, self._sharding_axis
                 )
             else:
                 placements = [
@@ -1580,8 +1576,7 @@ class ShardingStage3(_ShardingStageBase):
 
     Args:
         mesh(None|paddle.distributed.ProcessMesh): If mesh is not None, the `ProcessMesh` object describes the Cartesian topology of the used processes for dense type parameters. Note: Currently, only one mesh configuration is supported for all dense parameters. If there is a need for multiple mesh configurations, please configure them yourself in the upper layer networking code.
-        shard_dims(None|int|str): The sharding dimension name in the mesh.
-        shard_axis(int): The sharding axis of the weight tensor.
+        sharding_mesh_dim(None|int|str): The sharding dimension name in the mesh.
 
     Examples:
         .. code-block:: python
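Across the three docstring hunks, ShardingStage1, ShardingStage2 and ShardingStage3 now document the same mesh and sharding_mesh_dim arguments, so switching stages amounts to swapping the shard_fn passed to dist.shard_optimizer. A small sketch under that assumption (keyword names follow the docstrings above; what each stage shards is described in the usual ZeRO sense):

# Sketch: choosing a stage only changes the shard_fn; the keyword arguments
# are assumed from the docstrings in this diff, not verified against the constructors.
import paddle.distributed as dist

def make_shard_fn(stage, mesh, dim="dp"):
    stages = {
        1: dist.ShardingStage1,   # shard optimizer states
        2: dist.ShardingStage2,   # + shard gradients
        3: dist.ShardingStage3,   # + shard parameters
    }
    return stages[stage](sharding_mesh_dim=dim, mesh=mesh)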
@@ -1617,10 +1612,9 @@ class ShardingStage3(_ShardingStageBase):