@@ -718,68 +718,80 @@ def symbolic_call(
718
718
func : the logical function which will be represent as a stmt
719
719
"""
720
720
721
- def try_infer_meta_fn (args , kwargs ) -> Any :
722
- try :
723
- metas = convert_to_meta (args )
724
- kwmetas = convert_to_meta (kwargs )
725
- return args , kwargs , infer_meta_fn (func , * metas , ** kwmetas )
726
- except (NotSupportedTensorArgumentError , TypeError ) as e :
727
- bound_arguments = inspect .signature (func ).bind (* args , ** kwargs )
728
- bound_arguments .apply_defaults ()
729
- if (
730
- isinstance (e , NotSupportedTensorArgumentError )
731
- and e .name in bound_arguments .arguments
721
+ def infer_meta (args , kwargs ):
722
+ metas = convert_to_meta (args )
723
+ kwmetas = convert_to_meta (kwargs )
724
+ return infer_meta_fn (func , * metas , ** kwmetas )
725
+
726
+ def fallback_symbolic_to_constant (args , kwargs , err ):
727
+ bound_arguments = inspect .signature (func ).bind (* args , ** kwargs )
728
+ bound_arguments .apply_defaults ()
729
+ if (
730
+ isinstance (err , NotSupportedTensorArgumentError )
731
+ and err .name in bound_arguments .arguments
732
+ ):
733
+ original_var = bound_arguments .arguments [err .name ]
734
+ flatten_vars = original_var .flatten_inner_vars ()
735
+ if not any (
736
+ isinstance (arg , SymbolicVariable ) for arg in flatten_vars
732
737
):
733
- original_var = bound_arguments .arguments [e .name ]
734
- flatten_vars = original_var .flatten_inner_vars ()
735
- if not any (
736
- isinstance (arg , SymbolicVariable )
737
- for arg in flatten_vars
738
- ):
739
- # TODO(zrr1999): maybe we can continue to fallback to all args are constant.
740
- raise BreakGraphError (
741
- InferMetaBreak (
742
- f"InferMeta encount { type (e )} , but all args are not symbolic."
743
- )
738
+ # TODO(zrr1999): maybe we can continue to fallback to all args are constant.
739
+ raise BreakGraphError (
740
+ InferMetaBreak (
741
+ f"InferMeta encountered { type (err )} , but all args are not symbolic."
744
742
)
745
-
746
- args , kwargs = map_if (
747
- (args , kwargs ),
748
- pred = lambda x : x is original_var ,
749
- true_fn = lambda x : replace_symbolic_var_with_constant_var (
750
- x
751
- ),
752
- false_fn = lambda x : x ,
753
- )
754
- else :
755
- flatten_vars = reduce (
756
- lambda x , y : (
757
- x + y .flatten_inner_vars ()
758
- if isinstance (y , VariableBase )
759
- else x
760
- ),
761
- bound_arguments .arguments .values (),
762
- [],
763
743
)
764
744
765
- if not any (
766
- isinstance (arg , SymbolicVariable )
767
- for arg in flatten_vars
768
- ):
769
- raise BreakGraphError (
770
- InferMetaBreak (
771
- f"InferMeta encount { type (e )} , but all args are not symbolic."
772
- )
773
- )
745
+ args , kwargs = map_if (
746
+ (args , kwargs ),
747
+ pred = lambda x : x is original_var ,
748
+ true_fn = lambda x : replace_symbolic_var_with_constant_var (x ),
749
+ false_fn = lambda x : x ,
750
+ )
751
+ else :
752
+ flatten_vars = reduce (
753
+ lambda x , y : (
754
+ x + y .flatten_inner_vars ()
755
+ if isinstance (y , VariableBase )
756
+ else x
757
+ ),
758
+ bound_arguments .arguments .values (),
759
+ [],
760
+ )
774
761
775
- args , kwargs = map_structure (
776
- replace_symbolic_var_with_constant_var , (args , kwargs )
762
+ if not any (
763
+ isinstance (arg , SymbolicVariable ) for arg in flatten_vars
764
+ ):
765
+ raise BreakGraphError (
766
+ InferMetaBreak (
767
+ f"InferMeta encountered { type (err )} , but all args are not symbolic."
768
+ )
777
769
)
778
770
779
- metas = convert_to_meta (args )
780
- kwmetas = convert_to_meta (kwargs )
781
- return args , kwargs , infer_meta_fn (func , * metas , ** kwmetas )
771
+ args , kwargs = map_structure (
772
+ replace_symbolic_var_with_constant_var , (args , kwargs )
773
+ )
774
+ return args , kwargs
782
775
776
+ def try_infer_meta_with_fallback_symbolic_to_constant (
777
+ args , kwargs , max_retry_times = 10
778
+ ):
779
+ try :
780
+ return args , kwargs , infer_meta (args , kwargs )
781
+ except (NotSupportedTensorArgumentError , TypeError ) as e :
782
+ err = e
783
+ retry_times = 0
784
+ while True :
785
+ retry_times += 1
786
+ if retry_times >= max_retry_times :
787
+ raise err
788
+ try :
789
+ args , kwargs = fallback_symbolic_to_constant (
790
+ args , kwargs , err
791
+ )
792
+ return args , kwargs , infer_meta (args , kwargs )
793
+ except (NotSupportedTensorArgumentError , TypeError ) as e :
794
+ err = e
783
795
except Exception as e :
784
796
if SotExtraInfo .from_exception (e ).need_breakgraph :
785
797
raise BreakGraphError (
@@ -790,11 +802,11 @@ def try_infer_meta_fn(args, kwargs) -> Any:
790
802
raise e
791
803
792
804
if ENV_SOT_ALLOW_DYNAMIC_SHAPE .get ():
793
- args , kwargs , out_metas = try_infer_meta_fn (args , kwargs )
805
+ args , kwargs , out_metas = (
806
+ try_infer_meta_with_fallback_symbolic_to_constant (args , kwargs )
807
+ )
794
808
else :
795
- metas = convert_to_meta (args )
796
- kwmetas = convert_to_meta (kwargs )
797
- out_metas = infer_meta_fn (func , * metas , ** kwmetas )
809
+ out_metas = infer_meta (args , kwargs )
798
810
799
811
self .collect_input_variables (list (args ))
800
812
self .collect_input_variables (list (kwargs .values ()))
0 commit comments