Skip to content

Commit c91cfdc

Browse files
committed
Fix
1 parent f1b794d commit c91cfdc

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

test/legacy_test/test_activation_op_zero_size.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@
2525
TestCeil,
2626
TestCos,
2727
TestCosh,
28+
TestExpFp32_Prim,
29+
TestExpm1,
2830
TestFloor,
2931
TestLogSigmoid,
3032
TestReciprocal,
3133
TestRelu,
34+
TestRelu6,
3235
TestRsqrt,
3336
TestSigmoid,
3437
TestSilu,
@@ -96,7 +99,9 @@ def test_check_grad(self):
9699
create_test_zero_size_class(TestLogSigmoid)
97100
create_test_zero_size_class(TestFloor)
98101
create_test_zero_size_class(TestCeil)
99-
102+
create_test_zero_size_class(TestExpFp32_Prim)
103+
create_test_zero_size_class(TestExpm1)
104+
create_test_zero_size_class(TestRelu6)
100105

101106
if __name__ == "__main__":
102107
unittest.main()

test/legacy_test/test_bitwise_op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,11 @@ def init_shape(self):
392392
self.x_shape = []
393393

394394

395+
class TestBitwiseNot_ZeroSize(TestBitwiseNot):
396+
def init_shape(self):
397+
self.x_shape = [0, 3, 4, 5]
398+
399+
395400
class TestBitwiseNotUInt8(TestBitwiseNot):
396401
def init_dtype(self):
397402
self.dtype = np.uint8

test/legacy_test/test_polygamma_op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,5 +217,25 @@ def test_check_grad(self):
217217
)
218218

219219

220+
class TestPolygammaOp_ZeroSize(TestPolygammaOp):
221+
222+
def init_config(self):
223+
self.dtype = np.float64
224+
self.order = 1
225+
rand_case = np.random.randn(0).astype(self.dtype)
226+
int_case = np.random.randint(low=1, high=100, size=0).astype(self.dtype)
227+
self.case = np.concatenate([rand_case, int_case])
228+
self.inputs = {'x': self.case}
229+
self.attrs = {'n': self.order}
230+
self.target = ref_polygamma(self.inputs['x'], self.order)
231+
232+
def test_check_grad(self):
233+
self.check_grad(
234+
['x'],
235+
'out',
236+
check_pir=True,
237+
)
238+
239+
220240
if __name__ == "__main__":
221241
unittest.main()

0 commit comments

Comments
 (0)