Skip to content

Commit 10ec3ed

Browse files
authored
[MLU-FIX] fix(nonzero): add sync and fix its test (PaddlePaddle#391)
1 parent 0b11184 commit 10ec3ed

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

backends/mlu/kernels/nonzero_kernel.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ namespace custom_kernel {
2121

2222
template <typename T, typename Context>
2323
void NonZeroKernel(const Context& dev_ctx,
24-
const phi::DenseTensor& condition,
25-
phi::DenseTensor* out) {
24+
const phi::DenseTensor& condition,
25+
phi::DenseTensor* out) {
2626
auto dims = condition.dims();
2727
const int rank = dims.size();
2828

@@ -38,6 +38,7 @@ void NonZeroKernel(const Context& dev_ctx,
3838
GetBasePtr(&num_true));
3939

4040
Tensor local_true_num;
41+
dev_ctx.Wait(); // add sync for fully calculated results
4142
TensorCopy(dev_ctx, num_true, true, &local_true_num, phi::CPUPlace());
4243
auto true_num = *local_true_num.data<int>();
4344

backends/mlu/tests/unittests/test_where_index_op_mlu.py

100644100755
Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,21 +107,13 @@ def test_api(self):
107107
paddle.set_device("CustomMLU")
108108
with program_guard(Program(), Program()):
109109
cond = paddle.static.data(name="cond", shape=[-1, 4], dtype="bool")
110-
result = fluid.layers.where(cond)
110+
result = paddle.nonzero(cond)
111+
111112
exe = fluid.Executor(paddle.CustomPlace("CustomMLU", 0))
112113
exe.run(fluid.default_startup_program())
113114
cond_i = np.array([True, False, False, False]).astype("bool")
114115
out = exe.run(fluid.default_main_program(), feed={"cond": cond_i})
115116

116117

117-
class TestWhereRaiseError(unittest.TestCase):
118-
def test_errors(self):
119-
def test_type():
120-
paddle.set_device("CustomMLU")
121-
fluid.layers.where([10])
122-
123-
self.assertRaises(TypeError, test_type)
124-
125-
126118
if __name__ == "__main__":
127119
unittest.main()

0 commit comments

Comments
 (0)