Commit a0517fa

add new test
1 parent 063476a commit a0517fa

2 files changed (+99 -18 lines)

test/ir/pir/cinn/symbolic/test_if_dy.py (+7 -18)
@@ -35,14 +35,11 @@ def exp_sub(self, x):
         y = paddle.exp(x)
         return y - x

-    def forward(self, x, y):
+    def forward(self, x):
         if x.shape[-1] > 1:
             x = self.exp_sub(x)
-        else:
-            y = paddle.abs(y)
         x = paddle.nn.functional.relu(x)
-        y = paddle.logical_not(y)
-        return x, y
+        return x


 class TestIfSubgraph(unittest.TestCase):
@@ -55,36 +52,28 @@ def prepare_data(self):
         self.x = paddle.randn(self.shape, dtype="float32")
         self.x.stop_gradient = False

-        self.y_shape = [2, 256]
-        self.y = paddle.randn(self.y_shape, dtype="float32")
-        self.y.stop_gradient = False
-
     def check_jit_kernel_info(self, static_fn):
         utils.check_jit_kernel_number(static_fn, 1)
         utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

     def eval(self, use_cinn):
         net = IfSubgraph()
         input_spec = [
             InputSpec(shape=[None, None], dtype="float32"),
-            InputSpec(shape=[None, None], dtype="float32"),
         ]
         net = utils.apply_to_static(net, use_cinn, input_spec)
         net.eval()
-        out = net(self.x, self.y)
+        out = net(self.x)
         if use_cinn:
             self.check_jit_kernel_info(net.forward)
         return out

     def test_eval(self):
-        dy_out_x, dy_out_y = self.eval(use_cinn=False)
+        dy_out = self.eval(use_cinn=False)
         if utils.unittest_use_cinn():
-            cinn_out_x, cinn_out_y = self.eval(use_cinn=True)
-            np.testing.assert_allclose(
-                cinn_out_x.numpy(), dy_out_x.numpy(), atol=1e-6, rtol=1e-6
-            )
+            cinn_out = self.eval(use_cinn=True)
             np.testing.assert_allclose(
-                cinn_out_y.numpy(), dy_out_y.numpy(), atol=1e-6, rtol=1e-6
+                cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
             )

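Why the single-input version still exercises the interesting path: with InputSpec(shape=[None, None]), the last dimension of x is unknown when the program is captured, so x.shape[-1] > 1 cannot be folded into a Python bool and must be lowered as a shape-dependent conditional. A minimal sketch of the same pattern outside the test harness (the function name branch_fn and the sample shape are illustrative, not from the commit; assumes a PaddlePaddle build where paddle.jit.to_static accepts full_graph):

    import paddle
    from paddle.static import InputSpec

    def branch_fn(x):
        # shape=[None, None] leaves x.shape[-1] symbolic at capture time,
        # so this branch stays in the program instead of being folded away.
        if x.shape[-1] > 1:
            x = paddle.exp(x) - x  # same computation as exp_sub in the test
        return paddle.nn.functional.relu(x)

    static_fn = paddle.jit.to_static(
        branch_fn,
        input_spec=[InputSpec(shape=[None, None], dtype="float32")],
        full_graph=True,
    )
    out = static_fn(paddle.randn([1, 2048], dtype="float32"))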

New file (+92 -0)

@@ -0,0 +1,92 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+from os.path import dirname
+
+import numpy as np
+
+import paddle
+from paddle import nn
+from paddle.static import InputSpec
+
+sys.path.append(dirname(dirname(__file__)))
+
+import utils
+
+
+class IfSubgraph(nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def exp_sub(self, x):
+        y = paddle.exp(x)
+        return y - x
+
+    def forward(self, x, y):
+        if x.shape[-1] > 1:
+            x = self.exp_sub(x)
+        else:
+            y = paddle.abs(y)
+        x = paddle.nn.functional.relu(x)
+        y = paddle.logical_not(y)
+        return x, y
+
+
+class TestIfSubgraph(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(2024)
+        self.prepare_data()
+
+    def prepare_data(self):
+        self.shape = [1, 2048]
+        self.x = paddle.randn(self.shape, dtype="float32")
+        self.x.stop_gradient = False
+
+        self.y_shape = [2, 256]
+        self.y = paddle.randn(self.y_shape, dtype="float32")
+        self.y.stop_gradient = False
+
+    def check_jit_kernel_info(self, static_fn):
+        utils.check_jit_kernel_number(static_fn, 1)
+        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
+
+    def eval(self, use_cinn):
+        net = IfSubgraph()
+        input_spec = [
+            InputSpec(shape=[None, None], dtype="float32"),
+            InputSpec(shape=[None, None], dtype="float32"),
+        ]
+        net = utils.apply_to_static(net, use_cinn, input_spec)
+        net.eval()
+        out = net(self.x, self.y)
+        if use_cinn:
+            self.check_jit_kernel_info(net.forward)
+        return out
+
+    def test_eval(self):
+        dy_out_x, dy_out_y = self.eval(use_cinn=False)
+        if utils.unittest_use_cinn():
+            cinn_out_x, cinn_out_y = self.eval(use_cinn=True)
+            np.testing.assert_allclose(
+                cinn_out_x.numpy(), dy_out_x.numpy(), atol=1e-6, rtol=1e-6
+            )
+            np.testing.assert_allclose(
+                cinn_out_y.numpy(), dy_out_y.numpy(), atol=1e-6, rtol=1e-6
+            )
+
+
+if __name__ == '__main__':
+    unittest.main()
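Both versions of the test convert the layer through utils.apply_to_static, which is not part of this diff. A plausible sketch of that helper, modeled on similar CINN unit tests (the actual implementation under test/ir/pir/cinn may differ):

    import paddle

    def apply_to_static(net, use_cinn, input_spec=None):
        # Toggle the CINN compiler through the build strategy; the rest
        # is ordinary dynamic-to-static conversion.
        build_strategy = paddle.static.BuildStrategy()
        build_strategy.build_cinn_pass = use_cinn
        return paddle.jit.to_static(
            net,
            input_spec=input_spec,
            build_strategy=build_strategy,
            full_graph=True,
        )

Since each file ends with unittest.main(), either test can be run directly, e.g. python test/ir/pir/cinn/symbolic/test_if_dy.py.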
