Commit a8980c1

skip_sharding_check_in_moe (#71292)

1 parent: 3fc4ca3

File tree (1 file changed, +14 -4 lines):

  • python/paddle/distributed/auto_parallel/api.py

python/paddle/distributed/auto_parallel/api.py

Lines changed: 14 additions & 4 deletions

@@ -1140,10 +1140,20 @@ def _set_and_check_sharding_prop_from_param(self):
                     placements[self._sharding_axis], dist.Replicate
                 ), "The placement on sharding_axis should be Replicate"

-                # check the sharding degree since it has already been set
-                assert (
-                    mesh.dim_size(self._sharding_axis) == self._sharding_degree
-                ), "The sharding degree of all parameters must be equal currently."
+                # check the sharding degree since it has already been set,
+                # skip check when mesh is true subset of global_mesh
+                if global_mesh:
+                    if set(mesh.process_ids) < set(global_mesh.process_ids):
+                        continue
+                elif self._shard_fn._mesh:
+                    if set(mesh.process_ids) < set(
+                        self._shard_fn._mesh.process_ids
+                    ):
+                        continue
+                else:
+                    assert (
+                        mesh.dim_size(self._sharding_axis) == self._sharding_degree
+                    ), "The sharding degree of all parameters must be equal currently."

     def _shard_accumulator(self, param):
         target_name = param.name
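
Why the change: in an MoE model, expert parameters live on a sub-mesh that covers only part of the global mesh, so their sharding degree can legitimately differ from that of the dense parameters. The commit therefore skips the equal-degree assertion whenever a parameter's mesh is a proper subset of the global mesh (or of the shard_fn's mesh), and `set(a) < set(b)` is exactly Python's proper-subset test on the ranks. Below is a minimal standalone sketch of that test; `Mesh` is a hypothetical stand-in for paddle.distributed.ProcessMesh, not the real API.

# Minimal sketch of the subset test used in the diff above.
# Assumption: `Mesh` is a hypothetical stand-in exposing only the
# process ranks it spans, mirroring ProcessMesh.process_ids.

class Mesh:
    """A mesh reduced to the list of process ranks it spans."""

    def __init__(self, process_ids):
        self.process_ids = process_ids


def skip_degree_check(mesh, global_mesh):
    # `<` on sets is the proper-subset test: every rank of `mesh` is in
    # `global_mesh` AND `global_mesh` has at least one rank `mesh` lacks.
    # An identical mesh is NOT a proper subset, so it still gets checked.
    return set(mesh.process_ids) < set(global_mesh.process_ids)


expert_mesh = Mesh([0, 1])        # e.g. an MoE expert parameter's sub-mesh
global_mesh = Mesh([0, 1, 2, 3])  # mesh shared by the dense parameters

print(skip_degree_check(expert_mesh, global_mesh))  # True  -> skip assert
print(skip_degree_check(global_mesh, global_mesh))  # False -> run assert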
