From 200f92158c4ffb6c66a73ae81de657d9890b9f0f Mon Sep 17 00:00:00 2001 From: RedContritio Date: Sat, 21 Jan 2023 16:29:56 +0000 Subject: [PATCH 1/2] fix incorrect output shape of broadcast --- paddle/phi/infermeta/multiary.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 6b238209d4ac2f..6a16003806b671 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -791,7 +791,7 @@ void BroadcastTensorsInferMeta(const std::vector& x, // We performed bcast semantics check at python level // So input tensors should all have legal shape - target_dim_size = std::max(target_dim_size, dim_size); + target_dim_size = dim_size == 1 ? target_dim_size : dim_size; } target_dims[target_rank - index - 1] = target_dim_size; } From 454b9a4e7835a61e55187784452b629a9d0a3d29 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Fri, 27 Jan 2023 11:55:21 +0000 Subject: [PATCH 2/2] add unittest --- .../tests/unittests/test_broadcast_tensors_op.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py b/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py index 6eec711c49e0ab..9879aac254fb70 100644 --- a/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py +++ b/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py @@ -33,14 +33,12 @@ def find_output_shape(input_list): rank = len(x.shape) output_rank = max(output_rank, rank) - output_shape = [0 for i in range(output_rank)] + output_shape = [1 for i in range(output_rank)] for i in range(output_rank): for x in input_list: shape = list(reversed(x.shape)) - size = 1 - if i < len(shape): - size = shape[i] - output_shape[i] = max(output_shape[i], size) + if i < len(shape) and shape[i] != 1: + output_shape[i] = shape[i] return list(reversed(output_shape)) @@ -80,6 +78,11 @@ def gen_mixed_tensors_test(dtype): return make_inputs_outputs(input_shapes, dtype) +def gen_empty_tensors_test(dtype): + input_shapes = [(0), (0), (0)] + return make_inputs_outputs(input_shapes, dtype) + + class TestCPUBroadcastTensorsOp(OpTest): def set_place(self): self.place = core.CPUPlace() @@ -95,6 +98,7 @@ def setUp(self): gen_rank_diff_test, gen_no_broadcast_test, gen_mixed_tensors_test, + gen_empty_tensors_test, ] self.set_place() self.set_dtypes()