Skip to content

Commit 9b0618e

Browse files
update UT
1 parent 40de41b commit 9b0618e

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

test/prim/prim/vjp/test_comp_high_grad.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,21 @@ def setUpClass(cls):
164164
def subtract_wrapper(self, x):
165165
return paddle.subtract(x[0], x[1])
166166

167-
@prog_scope()
167+
def test_func_double_eager(self):
168+
shape1 = self.shape1
169+
shape2 = self.shape2
170+
dtype = np.float64
171+
x = paddle.randn(shape1, dtype=dtype)
172+
x.stop_gradient = False
173+
y = paddle.randn(shape2, dtype=dtype)
174+
y.stop_gradient = False
175+
out = paddle.subtract(x, y)
176+
dout = paddle.randn(out.shape)
177+
dout.stop_gradient = False
178+
dy = paddle.grad([out], [y], dout, create_graph=True)[0]
179+
ddout = paddle.grad(dy, dout)[0]
180+
np.testing.assert_allclose(ddout.numpy(), np.full(ddout.shape, -1.0))
181+
168182
def func_double(self, place):
169183
shape1 = self.shape1
170184
shape2 = self.shape2

0 commit comments

Comments
 (0)