Skip to content

Commit 8f5eae4

Browse files
authored
[Bug fixes] Fix bugs in some sparse test (#53428)
1 parent 4ccbcce commit 8f5eae4

File tree

3 files changed

+6
-0
lines changed

3 files changed

+6
-0
lines changed

python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def check_result(self, x_shape, new_shape, format):
3636
paddle.sparse.reshape.
3737
"""
3838
mask = np.random.randint(0, 2, x_shape)
39+
while np.sum(mask) == 0:
40+
mask = paddle.randint(0, 2, x_shape)
3941
np_x = np.random.randint(-100, 100, x_shape) * mask
4042

4143
# check cpu kernel

python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class TestTranspose(unittest.TestCase):
2323
# x: sparse, out: sparse
2424
def check_result(self, x_shape, dims, format):
2525
mask = paddle.randint(0, 2, x_shape).astype("float32")
26+
while paddle.sum(mask) == 0:
27+
mask = paddle.randint(0, 2, x_shape).astype("float32")
2628
# "+ 1" to make sure that all zero elements in "origin_x" is caused by multiplying by "mask",
2729
# or the backward checks may fail.
2830
origin_x = (paddle.rand(x_shape, dtype='float32') + 1) * mask

python/paddle/fluid/tests/unittests/test_sparse_unary_op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def to_sparse(self, x, format):
3030
def check_result(self, dense_func, sparse_func, format, *args):
3131
origin_x = paddle.rand([8, 16, 32], dtype='float32')
3232
mask = paddle.randint(0, 2, [8, 16, 32]).astype('float32')
33+
while paddle.sum(mask) == 0:
34+
mask = paddle.randint(0, 2, [8, 16, 32]).astype("float32")
3335

3436
# --- check sparse coo with dense --- #
3537
dense_x = origin_x * mask

0 commit comments

Comments
 (0)