Commit 91fa5ff

[auto parallel] shard optimizer enhance (#59575)
1 parent 39fda14 commit 91fa5ff

File tree: 5 files changed, +54 -2 lines


paddle/phi/api/yaml/ops.yaml (+1)

@@ -2304,6 +2304,7 @@
   output : Tensor(param_out), Tensor(master_param_out)
   infer_meta :
     func : SgdInferMeta
+    spmd_rule : SgdInferSpmd
   kernel :
     func : sgd {dense, dense, dense, dense -> dense, dense},
            sgd_dense_param_sparse_grad {dense, dense, selected_rows, dense -> dense, dense},
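The new `spmd_rule` entry tells the auto-parallel code path to call `SgdInferSpmd` when the `sgd_` op runs on distributed tensors, so the distributed attributes of its inputs and outputs (how each tensor axis maps onto the process mesh) are inferred alongside their shapes. As a point of reference only, not part of this commit, the sketch below shows what such a layout looks like through the dygraph semi-auto-parallel API; it assumes a two-card launch (for example via `python -m paddle.distributed.launch --devices 0,1`) and attribute names as in recent Paddle releases, which may differ by version.

```python
# Illustrative only (not from this commit): the "distributed attributes" that
# SgdInferSpmd propagates describe how each tensor axis maps onto the mesh.
# Assumes a 2-rank launch, e.g. python -m paddle.distributed.launch --devices 0,1.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
param = dist.shard_tensor(paddle.randn([8, 4]), mesh, [dist.Shard(0)])

print(param.shape)       # global shape [8, 4]
print(param.placements)  # [Shard(dim=0)]: axis 0 split over the mesh, axis 1 replicated
```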

paddle/phi/infermeta/spmd_rules/optimizer.cc (+44)

@@ -217,5 +217,49 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param,
                               use_global_beta_pow);
 }
 
+SpmdInfo SgdInferSpmd(const DistMetaTensor& param,
+                      const DistMetaTensor& learning_rate,
+                      const DistMetaTensor& grad,
+                      const DistMetaTensor& master_param,
+                      bool multi_precision) {
+  SpmdInfo param_grad_spmd = ElementwiseBinaryInferSpmd(param, grad);
+  TensorDistAttr param_dist_attr_spmd =
+      PADDLE_GET(TensorDistAttr, param_grad_spmd.first[0]);
+  TensorDistAttr grad_dist_attr_spmd =
+      PADDLE_GET(TensorDistAttr, param_grad_spmd.first[1]);
+
+  VLOG(3) << "The source dims mapping for param is: "
+          << auto_parallel::str_join(param.dist_attr().dims_mapping());
+  VLOG(3) << "The source dims mapping for grad is: "
+          << auto_parallel::str_join(grad.dist_attr().dims_mapping());
+  VLOG(3) << "The inter dims mapping for param (master param if available) "
+          << "after elementwise spmd is: "
+          << auto_parallel::str_join(param.dist_attr().dims_mapping());
+  VLOG(3) << "The inter dims mapping for grad after elementwise spmd is: "
+          << auto_parallel::str_join(grad.dist_attr().dims_mapping());
+
+  TensorDistAttr param_dist_attr =
+      CopyTensorDistAttrForOutput(param_dist_attr_spmd);
+  TensorDistAttr grad_dist_attr =
+      CopyTensorDistAttrForOutput(grad_dist_attr_spmd);
+  TensorDistAttr lr_dist_attr =
+      CopyTensorDistAttrForOutput(learning_rate.dist_attr());
+  TensorDistAttr master_param_dist_attr =
+      master_param.initialized()
+          ? CopyTensorDistAttrForOutput(master_param.dist_attr())
+          : TensorDistAttr();
+  param_dist_attr.set_dims_mapping(param_dist_attr_spmd.dims_mapping());
+  grad_dist_attr.set_dims_mapping(grad_dist_attr_spmd.dims_mapping());
+  if (master_param.initialized()) {
+    master_param_dist_attr.set_dims_mapping(
+        param_dist_attr_spmd.dims_mapping());
+  }
+  lr_dist_attr.set_dims_mapping(learning_rate.dist_attr().dims_mapping());
+
+  return {
+      {param_dist_attr, lr_dist_attr, grad_dist_attr, master_param_dist_attr},
+      {param_dist_attr, master_param_dist_attr}};
+}
+
 }  // namespace distributed
 }  // namespace phi
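To make the propagation logic above easier to follow, here is a dependency-free Python sketch of what `SgdInferSpmd` computes, using hypothetical helper names: `param` and `grad` are merged through the elementwise rule (`ElementwiseBinaryInferSpmd` in the real code), the learning rate keeps its own mapping, and the master parameter, when present, follows the merged parameter mapping. The real rule also handles resharding and broadcasting; this sketch only covers same-shape inputs.

```python
# Hypothetical, dependency-free sketch of the SgdInferSpmd propagation logic.
# A dims mapping maps each tensor axis to a mesh axis (-1 means replicated).

def merge_elementwise(mapping_a, mapping_b):
    """Simplified stand-in for ElementwiseBinaryInferSpmd on same-shape inputs:
    an axis is sharded if either input shards it; conflicts fall back to -1."""
    merged = []
    for a, b in zip(mapping_a, mapping_b):
        if a == b:
            merged.append(a)
        elif a == -1:
            merged.append(b)
        elif b == -1:
            merged.append(a)
        else:  # conflicting shardings: replicate (the real rule reshards inputs)
            merged.append(-1)
    return merged


def sgd_infer_spmd(param_map, lr_map, grad_map, master_param_map=None):
    merged = merge_elementwise(param_map, grad_map)
    master_out = list(merged) if master_param_map is not None else None
    # inputs: (param, learning_rate, grad, master_param)
    # outputs: (param_out, master_param_out)
    return ([merged, lr_map, merged, master_out], [merged, master_out])


# An [8, 4] parameter sharded on axis 0 of a 1-D mesh, grad replicated:
print(sgd_infer_spmd(param_map=[0, -1], lr_map=[], grad_map=[-1, -1],
                     master_param_map=[-1, -1]))
# -> ([[0, -1], [], [0, -1], [0, -1]], [[0, -1], [0, -1]])
```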

paddle/phi/infermeta/spmd_rules/optimizer.h (+6)

@@ -60,5 +60,11 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param,
                                bool multi_precision,
                                bool use_global_beta_pow);
 
+SpmdInfo SgdInferSpmd(const DistMetaTensor& param,
+                      const DistMetaTensor& learning_rate,
+                      const DistMetaTensor& grad,
+                      const DistMetaTensor& master_param,
+                      bool multi_precision);
+
 }  // namespace distributed
 }  // namespace phi

python/paddle/distributed/auto_parallel/api.py (+2, -2)

@@ -414,8 +414,8 @@ def __init__(self, optimizer, shard_fn=None):
             optimizer is not None
         ), "The argument `optimizer` cannot be empty."
         assert isinstance(
-            optimizer, paddle.optimizer.AdamW
-        ), "`paddle.distributed.ShardOptimizer` only supports AdamW optimizer for now."
+            optimizer, (paddle.optimizer.AdamW, paddle.optimizer.SGD)
+        ), "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."
 
         self.target_block = (
             paddle.base.framework.default_main_program().global_block()
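With this relaxed check, `dist.shard_optimizer` now accepts a plain SGD optimizer as well as AdamW. The snippet below is an illustrative usage sketch rather than code from this commit; it assumes a two-card run launched via `python -m paddle.distributed.launch --devices 0,1` and the dygraph semi-auto-parallel APIs (`ProcessMesh`, `shard_tensor`, `Shard`), whose exact signatures may differ across Paddle versions.

```python
# Illustrative sketch: wrapping SGD with dist.shard_optimizer (assumes 2 ranks).
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(16, 8)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
opt = dist.shard_optimizer(opt)  # now also valid for SGD, not only AdamW

for _ in range(5):
    x = paddle.randn([4, 16])
    x = dist.shard_tensor(x, mesh, [dist.Shard(0)])  # shard the batch dimension
    loss = layer(x).mean()
    loss.backward()
    opt.step()
    opt.clear_grad()
```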

test/auto_parallel/semi_auto_parallel_simple_net.py (+1)

@@ -130,6 +130,7 @@ def run_dynamic(self, layer, shard_input=False, is_pp=False):
         opt = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=layer.parameters()
         )
+        opt = dist.shard_optimizer(opt)
         for _ in range(5):
             image, label = self.init_input_data()
             if shard_input:
