Skip to content

Commit 9b7cd3d

Browse files
authored
[Auto Parallel]Add spmd rule for five ops(bitwise_or,atan2,fmax,fmin,reciprocal). (#72310)
* add unary ops which have spmd_rule but not add in yaml file. * Add spmd rule for five ops(bitwise_or,atan2,fmax,fmin,reciprocal). * add spmd_rule for atan2_grad, fmax_grad, fmin_grad.
1 parent 08beb1e commit 9b7cd3d

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

paddle/phi/ops/yaml/backward.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
infer_meta :
216216
func : GeneralBinaryGradInferMeta
217217
param : [x, y]
218+
spmd_rule : ElementwiseBinaryGradInferSpmd
218219
kernel :
219220
func : atan2_grad
220221

@@ -1230,6 +1231,7 @@
12301231
infer_meta :
12311232
func : GeneralBinaryGradInferMeta
12321233
param: [x, y]
1234+
spmd_rule : ElementwiseBinaryGradInferSpmd
12331235
kernel :
12341236
func : fmax_grad
12351237
data_type : out_grad
@@ -1241,6 +1243,7 @@
12411243
infer_meta :
12421244
func : GeneralBinaryGradInferMeta
12431245
param: [x, y]
1246+
spmd_rule : ElementwiseBinaryGradInferSpmd
12441247
kernel :
12451248
func : fmin_grad
12461249
data_type : out_grad
@@ -2602,6 +2605,7 @@
26022605
infer_meta :
26032606
func : UnchangedInferMeta
26042607
param : [out]
2608+
spmd_rule : ElementwiseUnaryGradInferSpmd
26052609
kernel :
26062610
func : reciprocal_grad
26072611
inplace : (out_grad -> x_grad)

paddle/phi/ops/yaml/ops.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@
450450
output : Tensor(out)
451451
infer_meta :
452452
func : Atan2InferMeta
453+
spmd_rule : ElementwiseBinaryInferSpmd
453454
kernel :
454455
func : atan2
455456
backward : atan2_grad
@@ -689,6 +690,7 @@
689690
output : Tensor(out)
690691
infer_meta :
691692
func : ElementwiseInferMeta
693+
spmd_rule : ElementwiseBinaryInferSpmd
692694
kernel :
693695
func : bitwise_or
694696
backend : x
@@ -2074,6 +2076,7 @@
20742076
infer_meta :
20752077
param: [x, y]
20762078
func : ElementwiseInferMeta
2079+
spmd_rule : ElementwiseBinaryInferSpmd
20772080
kernel :
20782081
func : fmax
20792082
backward : fmax_grad
@@ -2086,6 +2089,7 @@
20862089
infer_meta :
20872090
func : ElementwiseInferMeta
20882091
param: [x, y]
2092+
spmd_rule : ElementwiseBinaryInferSpmd
20892093
kernel :
20902094
func : fmin
20912095
backward : fmin_grad
@@ -4092,6 +4096,7 @@
40924096
output : Tensor(out)
40934097
infer_meta :
40944098
func : UnchangedInferMeta
4099+
spmd_rule : ElementwiseUnaryInferSpmd
40954100
kernel :
40964101
func : reciprocal
40974102
inplace : (x -> out)

0 commit comments

Comments
 (0)