@@ -217,5 +217,49 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param,
217
217
use_global_beta_pow);
218
218
}
219
219
220
+ SpmdInfo SgdInferSpmd (const DistMetaTensor& param,
221
+ const DistMetaTensor& learning_rate,
222
+ const DistMetaTensor& grad,
223
+ const DistMetaTensor& master_param,
224
+ bool multi_precision) {
225
+ SpmdInfo param_grad_spmd = ElementwiseBinaryInferSpmd (param, grad);
226
+ TensorDistAttr param_dist_attr_spmd =
227
+ PADDLE_GET (TensorDistAttr, param_grad_spmd.first [0 ]);
228
+ TensorDistAttr grad_dist_attr_spmd =
229
+ PADDLE_GET (TensorDistAttr, param_grad_spmd.first [1 ]);
230
+
231
+ VLOG (3 ) << " The source dims mapping for param is: "
232
+ << auto_parallel::str_join (param.dist_attr ().dims_mapping ());
233
+ VLOG (3 ) << " The source dims mapping for grad is: "
234
+ << auto_parallel::str_join (grad.dist_attr ().dims_mapping ());
235
+ VLOG (3 ) << " The inter dims mapping for param (master param if available) "
236
+ << " after elementwise spmd is: "
237
+ << auto_parallel::str_join (param.dist_attr ().dims_mapping ());
238
+ VLOG (3 ) << " The inter dims mapping for grad after elementwise spmd is: "
239
+ << auto_parallel::str_join (grad.dist_attr ().dims_mapping ());
240
+
241
+ TensorDistAttr param_dist_attr =
242
+ CopyTensorDistAttrForOutput (param_dist_attr_spmd);
243
+ TensorDistAttr grad_dist_attr =
244
+ CopyTensorDistAttrForOutput (grad_dist_attr_spmd);
245
+ TensorDistAttr lr_dist_attr =
246
+ CopyTensorDistAttrForOutput (learning_rate.dist_attr ());
247
+ TensorDistAttr master_param_dist_attr =
248
+ master_param.initialized ()
249
+ ? CopyTensorDistAttrForOutput (master_param.dist_attr ())
250
+ : TensorDistAttr ();
251
+ param_dist_attr.set_dims_mapping (param_dist_attr_spmd.dims_mapping ());
252
+ grad_dist_attr.set_dims_mapping (grad_dist_attr_spmd.dims_mapping ());
253
+ if (master_param.initialized ()) {
254
+ master_param_dist_attr.set_dims_mapping (
255
+ param_dist_attr_spmd.dims_mapping ());
256
+ }
257
+ lr_dist_attr.set_dims_mapping (learning_rate.dist_attr ().dims_mapping ());
258
+
259
+ return {
260
+ {param_dist_attr, lr_dist_attr, grad_dist_attr, master_param_dist_attr},
261
+ {param_dist_attr, master_param_dist_attr}};
262
+ }
263
+
220
264
} // namespace distributed
221
265
} // namespace phi
0 commit comments