@@ -2839,7 +2839,7 @@ def false_func():
2839
2839
# Merge ture and false output if they are not None
2840
2840
if return_names is None :
2841
2841
is_dy2staic = False
2842
- return_names = ["no name" ] * len (to_sequence (true_output ))
2842
+ return_names = ["no name" ] * len (_to_sequence_except_dict (true_output ))
2843
2843
else :
2844
2844
"""
2845
2845
dy2static will set the return_names and expand the return values to UndefinedVar.
@@ -2855,16 +2855,19 @@ def false_func():
2855
2855
true_output , false_output , return_names
2856
2856
)
2857
2857
2858
- if len (to_sequence (true_output )) != len (to_sequence (false_output )):
2858
+ if len (_to_sequence_except_dict (true_output )) != len (
2859
+ _to_sequence_except_dict (false_output )
2860
+ ):
2859
2861
raise ValueError (
2860
2862
"true fn returns {} vars, but false fn returns {} vars, which is not equals" .format (
2861
- len (to_sequence (true_output )), len (to_sequence (false_output ))
2863
+ len (_to_sequence_except_dict (true_output )),
2864
+ len (_to_sequence_except_dict (false_output )),
2862
2865
)
2863
2866
)
2864
2867
for true_out , false_out , return_name in zip (
2865
- to_sequence (true_output ),
2866
- to_sequence (false_output ),
2867
- to_sequence (return_names ),
2868
+ _to_sequence_except_dict (true_output ),
2869
+ _to_sequence_except_dict (false_output ),
2870
+ _to_sequence_except_dict (return_names ),
2868
2871
):
2869
2872
try :
2870
2873
assert_same_structure (true_out , false_out , check_types = False )
@@ -2876,10 +2879,9 @@ def false_func():
2876
2879
)
2877
2880
2878
2881
def check_ret_none (seq_true , seq_false , seq_names ):
2879
- length = len (seq_true )
2880
- for i in range (length ):
2881
- f_true = flatten (seq_true [i ])
2882
- f_false = flatten (seq_false [i ])
2882
+ for f_true , f_false , f_name in zip (seq_true , seq_false , seq_names ):
2883
+ f_true = flatten (f_true )
2884
+ f_false = flatten (f_false )
2883
2885
for idx in range (len (f_true )):
2884
2886
if (
2885
2887
f_true [idx ] is None
@@ -2891,7 +2893,7 @@ def check_ret_none(seq_true, seq_false, seq_names):
2891
2893
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
2892
2894
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
2893
2895
"'None' in ifelse block might lead to error." .format (
2894
- seq_names [ i ] ,
2896
+ f_name ,
2895
2897
type (f_true [idx ]),
2896
2898
f_true [idx ],
2897
2899
type (f_false [idx ]),
@@ -2900,9 +2902,9 @@ def check_ret_none(seq_true, seq_false, seq_names):
2900
2902
)
2901
2903
2902
2904
check_ret_none (
2903
- to_sequence (true_output ),
2904
- to_sequence (false_output ),
2905
- to_sequence (return_names ),
2905
+ _to_sequence_except_dict (true_output ),
2906
+ _to_sequence_except_dict (false_output ),
2907
+ _to_sequence_except_dict (return_names ),
2906
2908
)
2907
2909
2908
2910
if is_dy2staic :
@@ -2923,9 +2925,9 @@ def merge_every_var_list(false_vars, true_vars, name):
2923
2925
merged_output = list (
2924
2926
map (
2925
2927
merge_every_var_list ,
2926
- to_sequence (false_output ),
2927
- to_sequence (true_output ),
2928
- to_sequence (return_names ),
2928
+ _to_sequence_except_dict (false_output ),
2929
+ _to_sequence_except_dict (true_output ),
2930
+ _to_sequence_except_dict (return_names ),
2929
2931
)
2930
2932
)
2931
2933
merged_output = pack_sequence_as (false_output , flatten (merged_output ))
@@ -2945,6 +2947,24 @@ def map_fn(x):
2945
2947
return nest1_out , nest2_out
2946
2948
2947
2949
2950
+ def _to_sequence_except_dict (x ):
2951
+ """
2952
+ In this function, dict is not viewed as sequence.
2953
+ """
2954
+ if isinstance (x , dict ):
2955
+ return [x ]
2956
+ return to_sequence (x )
2957
+
2958
+
2959
+ def _is_sequence_except_dict (x ):
2960
+ """
2961
+ In this function, dict is not viewed as sequence.
2962
+ """
2963
+ if isinstance (x , dict ):
2964
+ return False
2965
+ return is_sequence (x )
2966
+
2967
+
2948
2968
def expand_undefined_var (nest1 , nest2 , names ):
2949
2969
"""TODO: make this function recursively.
2950
2970
nest1: Var1, (UndefinedVar, [1,2,3])
@@ -2988,24 +3008,24 @@ def map_fn(n1, n2, name, order):
2988
3008
nest1_out = list (
2989
3009
map (
2990
3010
map_fn ,
2991
- to_sequence (nest1 ),
2992
- to_sequence (nest2 ),
2993
- to_sequence (names ),
2994
- [0 for i in to_sequence (names )],
3011
+ _to_sequence_except_dict (nest1 ),
3012
+ _to_sequence_except_dict (nest2 ),
3013
+ _to_sequence_except_dict (names ),
3014
+ [0 for i in _to_sequence_except_dict (names )],
2995
3015
)
2996
3016
)
2997
3017
nest2_out = list (
2998
3018
map (
2999
3019
map_fn ,
3000
- to_sequence (nest2 ),
3001
- to_sequence (nest1 ),
3002
- to_sequence (names ),
3003
- [1 for i in to_sequence (names )],
3020
+ _to_sequence_except_dict (nest2 ),
3021
+ _to_sequence_except_dict (nest1 ),
3022
+ _to_sequence_except_dict (names ),
3023
+ [1 for i in _to_sequence_except_dict (names )],
3004
3024
)
3005
3025
)
3006
- if not is_sequence (nest1 ):
3026
+ if not _is_sequence_except_dict (nest1 ):
3007
3027
nest1_out = nest1_out [0 ]
3008
- if not is_sequence (nest2 ):
3028
+ if not _is_sequence_except_dict (nest2 ):
3009
3029
nest2_out = nest2_out [0 ]
3010
3030
return nest1_out , nest2_out
3011
3031
0 commit comments