Commit 555ad25 (parent: c57faca)

[Prim][PIR] Support dynamic shape for softmax_grad (#65961)

* Support dynamic shape for softmax_grad
* Add the config entry for pd_op.softmax

3 files changed: +43, -1 lines

paddle/fluid/primitive/rule/vjp/details.h (+2, -1)

@@ -1084,7 +1084,8 @@ void softmax_grad(const Tensor& out,
         set_output<T>(tmp_x_grad, x_grad);
       }
     } else {
-      set_output<T>(out_grad * 0.0, x_grad);
+      Tensor zeros = full_scalar<T>(0.0, out.dtype());
+      set_output<T>(out_grad * zeros, x_grad);
     }
   }
 }
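Why this helps with dynamic shapes: multiplying out_grad by the literal 0.0 presumably lowers to a zero constant of out_grad's compile-time shape, which is unknown when dims are symbolic; full_scalar<T> instead builds a 0-D scalar zero that broadcasts to the runtime shape. A minimal Python sketch of the same broadcasting idea (illustrative only, not the C++ primitive API):

    import paddle

    # A 0-D zero tensor broadcasts to out_grad's runtime shape, so no
    # static shape information is needed to produce the zero gradient.
    out_grad = paddle.rand([4, 5])
    zeros = paddle.full(shape=[], fill_value=0.0, dtype=out_grad.dtype)  # like full_scalar<T>
    x_grad = out_grad * zeros
    print(x_grad.shape)  # [4, 5]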

python/paddle/autograd/backward_utils.py (+1)

@@ -53,6 +53,7 @@
     "pd_op.divide",
     "pd_op.pow",
     "pd_op.elementwise_pow",
+    "pd_op.softmax",
 ]
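This list appears to act as an allow-list of ops whose composite VJP rules are enabled under dynamic shapes; registering pd_op.softmax is what lets the rewritten softmax_grad rule above be used when input dims are symbolic.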
test/prim/pir_prim/test_prim_sub_graph_backward_dynamic_shape.py (+40)

@@ -115,6 +115,10 @@ def pow_net(x):
     return paddle.pow(x, 3.2)
 
 
+def softmax_net(x):
+    return paddle.nn.functional.softmax(x, axis=-1)
+
+
 def apply_to_static(net, use_cinn, input_spec=None):
     build_strategy = paddle.static.BuildStrategy()
     build_strategy.build_cinn_pass = use_cinn

@@ -1330,5 +1334,41 @@ def setUp(self):
         self.tol = 1e-6
 
 
+class TestPrimSoftmaxWithGrad1(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = softmax_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimSoftmaxWithGrad2(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, 40]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = softmax_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimSoftmaxWithGrad3(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [30, 200, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = softmax_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 if __name__ == "__main__":
     unittest.main()
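The three cases differ only in init_x_shape: all dims dynamic, the two leading dims dynamic, or only the trailing dim dynamic. A minimal sketch of what such a spec means, assuming the None dims become symbolic dims of a paddle.static.InputSpec when the net is compiled (the wiring below is illustrative; the real plumbing lives in TestPrimBaseWithGrad):

    import numpy as np
    import paddle

    def softmax_net(x):
        return paddle.nn.functional.softmax(x, axis=-1)

    # Dims set to None in the InputSpec are symbolic, so the compiled
    # program, including the composite softmax_grad above, cannot rely
    # on knowing them statically.
    spec = [paddle.static.InputSpec(shape=[None, None, 40], dtype="float32")]
    static_net = paddle.jit.to_static(softmax_net, input_spec=spec)

    x = paddle.to_tensor(np.random.random([30, 200, 40]).astype("float32"))
    x.stop_gradient = False
    out = static_net(x)
    x_grad = paddle.grad(out, x)[0]  # exercises the softmax_grad VJP rule
    print(x_grad.shape)  # [30, 200, 40]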
