Skip to content

Commit 13a5f18

Browse files
authored
[BugFix] while cond receives dict as input (#47299)
* fix bugs while cond receives dict as input * add unittest * change flatten -> _is_sequence_except_dict
1 parent ac3b882 commit 13a5f18

File tree

2 files changed

+86
-27
lines changed

2 files changed

+86
-27
lines changed

python/paddle/fluid/layers/control_flow.py

+47-27
Original file line numberDiff line numberDiff line change
@@ -2839,7 +2839,7 @@ def false_func():
28392839
# Merge ture and false output if they are not None
28402840
if return_names is None:
28412841
is_dy2staic = False
2842-
return_names = ["no name"] * len(to_sequence(true_output))
2842+
return_names = ["no name"] * len(_to_sequence_except_dict(true_output))
28432843
else:
28442844
"""
28452845
dy2static will set the return_names and expand the return values to UndefinedVar.
@@ -2855,16 +2855,19 @@ def false_func():
28552855
true_output, false_output, return_names
28562856
)
28572857

2858-
if len(to_sequence(true_output)) != len(to_sequence(false_output)):
2858+
if len(_to_sequence_except_dict(true_output)) != len(
2859+
_to_sequence_except_dict(false_output)
2860+
):
28592861
raise ValueError(
28602862
"true fn returns {} vars, but false fn returns {} vars, which is not equals".format(
2861-
len(to_sequence(true_output)), len(to_sequence(false_output))
2863+
len(_to_sequence_except_dict(true_output)),
2864+
len(_to_sequence_except_dict(false_output)),
28622865
)
28632866
)
28642867
for true_out, false_out, return_name in zip(
2865-
to_sequence(true_output),
2866-
to_sequence(false_output),
2867-
to_sequence(return_names),
2868+
_to_sequence_except_dict(true_output),
2869+
_to_sequence_except_dict(false_output),
2870+
_to_sequence_except_dict(return_names),
28682871
):
28692872
try:
28702873
assert_same_structure(true_out, false_out, check_types=False)
@@ -2876,10 +2879,9 @@ def false_func():
28762879
)
28772880

28782881
def check_ret_none(seq_true, seq_false, seq_names):
2879-
length = len(seq_true)
2880-
for i in range(length):
2881-
f_true = flatten(seq_true[i])
2882-
f_false = flatten(seq_false[i])
2882+
for f_true, f_false, f_name in zip(seq_true, seq_false, seq_names):
2883+
f_true = flatten(f_true)
2884+
f_false = flatten(f_false)
28832885
for idx in range(len(f_true)):
28842886
if (
28852887
f_true[idx] is None
@@ -2891,7 +2893,7 @@ def check_ret_none(seq_true, seq_false, seq_names):
28912893
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
28922894
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
28932895
"'None' in ifelse block might lead to error.".format(
2894-
seq_names[i],
2896+
f_name,
28952897
type(f_true[idx]),
28962898
f_true[idx],
28972899
type(f_false[idx]),
@@ -2900,9 +2902,9 @@ def check_ret_none(seq_true, seq_false, seq_names):
29002902
)
29012903

29022904
check_ret_none(
2903-
to_sequence(true_output),
2904-
to_sequence(false_output),
2905-
to_sequence(return_names),
2905+
_to_sequence_except_dict(true_output),
2906+
_to_sequence_except_dict(false_output),
2907+
_to_sequence_except_dict(return_names),
29062908
)
29072909

29082910
if is_dy2staic:
@@ -2923,9 +2925,9 @@ def merge_every_var_list(false_vars, true_vars, name):
29232925
merged_output = list(
29242926
map(
29252927
merge_every_var_list,
2926-
to_sequence(false_output),
2927-
to_sequence(true_output),
2928-
to_sequence(return_names),
2928+
_to_sequence_except_dict(false_output),
2929+
_to_sequence_except_dict(true_output),
2930+
_to_sequence_except_dict(return_names),
29292931
)
29302932
)
29312933
merged_output = pack_sequence_as(false_output, flatten(merged_output))
@@ -2945,6 +2947,24 @@ def map_fn(x):
29452947
return nest1_out, nest2_out
29462948

29472949

2950+
def _to_sequence_except_dict(x):
2951+
"""
2952+
In this function, dict is not viewed as sequence.
2953+
"""
2954+
if isinstance(x, dict):
2955+
return [x]
2956+
return to_sequence(x)
2957+
2958+
2959+
def _is_sequence_except_dict(x):
2960+
"""
2961+
In this function, dict is not viewed as sequence.
2962+
"""
2963+
if isinstance(x, dict):
2964+
return False
2965+
return is_sequence(x)
2966+
2967+
29482968
def expand_undefined_var(nest1, nest2, names):
29492969
"""TODO: make this function recursively.
29502970
nest1: Var1, (UndefinedVar, [1,2,3])
@@ -2988,24 +3008,24 @@ def map_fn(n1, n2, name, order):
29883008
nest1_out = list(
29893009
map(
29903010
map_fn,
2991-
to_sequence(nest1),
2992-
to_sequence(nest2),
2993-
to_sequence(names),
2994-
[0 for i in to_sequence(names)],
3011+
_to_sequence_except_dict(nest1),
3012+
_to_sequence_except_dict(nest2),
3013+
_to_sequence_except_dict(names),
3014+
[0 for i in _to_sequence_except_dict(names)],
29953015
)
29963016
)
29973017
nest2_out = list(
29983018
map(
29993019
map_fn,
3000-
to_sequence(nest2),
3001-
to_sequence(nest1),
3002-
to_sequence(names),
3003-
[1 for i in to_sequence(names)],
3020+
_to_sequence_except_dict(nest2),
3021+
_to_sequence_except_dict(nest1),
3022+
_to_sequence_except_dict(names),
3023+
[1 for i in _to_sequence_except_dict(names)],
30043024
)
30053025
)
3006-
if not is_sequence(nest1):
3026+
if not _is_sequence_except_dict(nest1):
30073027
nest1_out = nest1_out[0]
3008-
if not is_sequence(nest2):
3028+
if not _is_sequence_except_dict(nest2):
30093029
nest2_out = nest2_out[0]
30103030
return nest1_out, nest2_out
30113031

python/paddle/fluid/tests/unittests/test_cond.py

+39
Original file line numberDiff line numberDiff line change
@@ -676,5 +676,44 @@ def func():
676676
layers.cond(pred, func, func, set())
677677

678678

679+
class TestCondWithDict(unittest.TestCase):
680+
def test_input_with_dict(self):
681+
paddle.enable_static()
682+
main_program = framework.Program()
683+
startup_program = framework.Program()
684+
with framework.program_guard(main_program, startup_program):
685+
686+
def true_func():
687+
return {
688+
'1': paddle.full(shape=[3, 2], dtype='int32', fill_value=1),
689+
'2': paddle.full(
690+
shape=[2, 3], dtype='bool', fill_value=True
691+
),
692+
}
693+
694+
def false_func():
695+
return {
696+
'1': paddle.full(
697+
shape=[3, 4], dtype='float32', fill_value=3
698+
),
699+
'2': paddle.full(shape=[4, 5], dtype='int64', fill_value=2),
700+
}
701+
702+
x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
703+
y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
704+
pred = paddle.less_than(x=x, y=y, name=None)
705+
ret = paddle.static.nn.cond(pred, true_func, false_func)
706+
self.assertEqual(
707+
ret['1'].shape,
708+
(3, -1),
709+
f"The shape is not correct, expects (3, -1) but gets {ret['1'].shape}.",
710+
)
711+
self.assertEqual(
712+
ret['2'].shape,
713+
(-1, -1),
714+
f"The shape is not correct, expects (-1, -1) but gets {ret['2'].shape}.",
715+
)
716+
717+
679718
if __name__ == '__main__':
680719
unittest.main()

0 commit comments

Comments
 (0)