1 file changed: +14 −4 lines changed
python/paddle/distributed/auto_parallel

@@ -1140,10 +1140,20 @@ def _set_and_check_sharding_prop_from_param(self):
                 placements[self._sharding_axis], dist.Replicate
             ), "The placement on sharding_axis should be Replicate"
 
-            # check the sharding degree since it has already been set
-            assert (
-                mesh.dim_size(self._sharding_axis) == self._sharding_degree
-            ), "The sharding degree of all parameters must be equal currently."
+            # check the sharding degree since it has already been set,
+            # skip check when mesh is true subset of global_mesh
+            if global_mesh:
+                if set(mesh.process_ids) < set(global_mesh.process_ids):
+                    continue
+            elif self._shard_fn._mesh:
+                if set(mesh.process_ids) < set(
+                    self._shard_fn._mesh.process_ids
+                ):
+                    continue
+            else:
+                assert (
+                    mesh.dim_size(self._sharding_axis) == self._sharding_degree
+                ), "The sharding degree of all parameters must be equal currently."
 
     def _shard_accumulator(self, param):
         target_name = param.name
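
The new branches rely on Python's `<` operator for sets, which tests for a proper ("true") subset and is False when the two sets are equal. A minimal sketch of that semantics, with hypothetical process-id lists standing in for the real mesh objects:

# Minimal sketch of the proper-subset test used in the diff above; the
# id sets are hypothetical stand-ins for mesh.process_ids and
# global_mesh.process_ids.
global_ids = {0, 1, 2, 3, 4, 5, 6, 7}  # hypothetical global mesh ranks
sub_ids = {0, 1, 2, 3}                 # hypothetical parameter mesh ranks

# `<` on sets is a proper ("true") subset test: it is False when the two
# sets are equal, so only parameters living on a strictly smaller mesh
# skip the sharding-degree check.
print(sub_ids < global_ids)     # True  -> skip the sharding-degree check
print(global_ids < global_ids)  # False -> not a true subset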