Commit cf20920

add prim test for sqrt and exp (#50942)

1 parent 6786c01 commit cf20920

3 files changed: +125 -1 lines

File tree:
- python/paddle/fluid/tests/unittests/CMakeLists.txt (+1 -1)
- python/paddle/fluid/tests/unittests/prim_op_test.py (+10)
- python/paddle/fluid/tests/unittests/test_activation_op.py (+114)

python/paddle/fluid/tests/unittests/CMakeLists.txt (+1 -1)
@@ -1203,7 +1203,7 @@ if($ENV{USE_STANDALONE_EXECUTOR})
 endif()

 set(TEST_CINN_OPS test_softmax_op test_expand_v2_op test_reduce_op
-                  test_slice_op)
+                  test_slice_op test_activation_op)

 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
   if(WITH_CINN)
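Note: adding test_activation_op to TEST_CINN_OPS registers the suite with the foreach loop shown above, which presumably builds an additional CINN-enabled variant of each listed test when WITH_CINN is on. The new prim test classes below all set enable_cinn = False, so for now they opt out of CINN checking within that variant.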

python/paddle/fluid/tests/unittests/prim_op_test.py (+10)
@@ -378,6 +378,11 @@ def init_checker_threshold(self):
         )

     def check(self):
+        if (
+            self.place is paddle.fluid.libpaddle.CUDAPlace
+            and not paddle.is_compiled_with_cuda()
+        ):
+            return
         self.eager_desire = self.get_eager_desire()
         if self.enable_check_static_comp:
             self.check_static_comp()
@@ -773,6 +778,11 @@ def init(self):
         self.checker_name = "PrimGradChecker"

     def check(self):
+        if (
+            self.place is paddle.fluid.libpaddle.CUDAPlace
+            and not paddle.is_compiled_with_cuda()
+        ):
+            return
         self.eager_desire = self.get_eager_desire()
         if self.enable_check_eager_comp:
             self.check_eager_comp()
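For context, a minimal standalone sketch of what these two early returns do (assuming the same Paddle build the diff targets; should_skip is an invented name): the prim checkers now bail out instead of attempting CUDA checks on a CPU-only wheel. Note the diff compares place against the CUDAPlace class with `is`, so the guard fires when the checker holds the class object itself rather than an instance.

    import paddle

    def should_skip(place):
        # Mirrors the early return added to both check() methods: skip when
        # the target place is CUDA but the installed wheel was built
        # without CUDA support.
        return (
            place is paddle.fluid.libpaddle.CUDAPlace
            and not paddle.is_compiled_with_cuda()
        )

    print(should_skip(paddle.fluid.libpaddle.CUDAPlace))  # True on a CPU-only build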

python/paddle/fluid/tests/unittests/test_activation_op.py (+114)
@@ -90,6 +90,72 @@ def init_shape(self):
         self.shape = []


+class TestExpPrimFp32(OpTest):
+    def setUp(self):
+        self.op_type = "exp"
+        self.prim_op_type = "prim"
+        self.init_dtype()
+        self.init_shape()
+        self.python_api = paddle.exp
+
+        np.random.seed(2049)
+        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        out = np.exp(x)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+        self.skip_cinn()
+        self.set_only_prim()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_prim=True)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def init_shape(self):
+        self.shape = [12, 17]
+
+    def skip_cinn(self):
+        self.enable_cinn = False
+
+    def set_only_prim(self):
+        pass
+
+
+class TestExpPrimFp64(TestExpPrimFp32):
+    def init_dtype(self):
+        self.dtype = np.float64
+
+
+class TestExpPrimFp16(TestExpPrimFp32):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def set_only_prim(self):
+        self.only_prim = True
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_prim=True)
+
+    def skip_cinn(self):
+        self.enable_cinn = False
+
+
+class TestExpPrim_ZeroDim(TestExpPrimFp32):
+    def init_shape(self):
+        self.shape = []
+
+    def skip_cinn(self):
+        self.enable_cinn = False
+
+
 class TestExpm1(TestActivation):
     def setUp(self):
         self.op_type = "expm1"
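As a cross-check on what these prim tests assert, here is a hedged numpy-only sketch (independent of Paddle) comparing the analytic derivatives the composite rules should reproduce, exp'(x) = exp(x) and sqrt'(x) = 1/(2*sqrt(x)), against central finite differences on the same uniform(0.1, 1) input range the tests use:

    import numpy as np

    np.random.seed(2049)
    x = np.random.uniform(0.1, 1, (12, 17)).astype(np.float64)
    eps = 1e-6

    for fn, grad in [(np.exp, np.exp), (np.sqrt, lambda v: 0.5 / np.sqrt(v))]:
        # Central finite difference approximates the elementwise derivative.
        fd = (fn(x + eps) - fn(x - eps)) / (2 * eps)
        assert np.allclose(fd, grad(x), rtol=1e-5), fn.__name__
    print("exp and sqrt gradients match finite differences")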
@@ -167,6 +233,8 @@ def test_errors(self):
 class TestParameter:
     def test_out_name(self):
         with fluid.program_guard(fluid.Program()):
+            if paddle.fluid.framework.in_dygraph_mode():
+                paddle.enable_static()
             np_x = np.array([0.1]).astype('float32').reshape((-1, 1))
             data = paddle.static.data(name="X", shape=[-1, 1], dtype="float32")
             out = eval("paddle.%s(data, name='Y')" % self.op_type)
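The guard added here presumably exists because the new prim checks run eagerly and can leave dygraph mode enabled by the time TestParameter.test_out_name executes; this test builds a static Program, so it switches back with paddle.enable_static() first.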
@@ -1062,6 +1130,7 @@ def test_errors(self):
 class TestSqrt(TestActivation, TestParameter):
     def setUp(self):
         self.op_type = "sqrt"
+        self.prim_op_type = "prim"
         self.python_api = paddle.sqrt
         self.init_dtype()
         self.init_shape()
@@ -1072,7 +1141,9 @@ def setUp(self):

         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
+        self.enable_cinn = False

+    # TODO(wanghao107) add prim test
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
@@ -1082,17 +1153,58 @@ def test_check_output(self):
         self.check_output(check_eager=True)


+class TestSqrtPrimFp32(TestActivation):
+    def setUp(self):
+        self.op_type = "sqrt"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.sqrt
+        self.init_dtype()
+        self.init_shape()
+        np.random.seed(1023)
+        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        out = np.sqrt(x)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+        self.enable_cinn = False
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+
 class TestSqrt_ZeroDim(TestSqrt):
     def init_shape(self):
         self.shape = []


+class TestSqrtPrim_ZeroDim(TestSqrt):
+    def init_shape(self):
+        self.shape = []
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad(['X'], 'Out', check_prim=True)
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
 class TestSqrtBF16(OpTest):
     def setUp(self):
         self.op_type = "sqrt"
+        self.prim_op_type = "prim"
         self.python_api = paddle.sqrt
         self.init_dtype()
         self.init_shape()
@@ -1105,6 +1217,8 @@ def setUp(self):
             'X': OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(x))
         }
         self.outputs = {'Out': convert_float_to_uint16(out)}
+        # TODO(wanghao107): add prim test
+        self.enable_cinn = False

     def init_dtype(self):
         self.dtype = np.uint16
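To run just the cases added in this commit, one option is the standard unittest loader; a hedged sketch, assuming it is executed from the directory containing test_activation_op.py with Paddle installed:

    import unittest

    # Load only the new prim test classes from this commit.
    suite = unittest.defaultTestLoader.loadTestsFromNames(
        [
            "test_activation_op.TestExpPrimFp32",
            "test_activation_op.TestSqrtPrimFp32",
            "test_activation_op.TestSqrtPrim_ZeroDim",
        ]
    )
    unittest.TextTestRunner(verbosity=2).run(suite)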
