Skip to content

Commit dc813e5

Browse files
SigureMoCopilot
andauthored
[SOT][DynamicShape][PHI] Fallback symbolic variable until success when infermeta and fix unfold infermeta (#72704)
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent ce0a5f5 commit dc813e5

File tree

3 files changed

+77
-62
lines changed

3 files changed

+77
-62
lines changed

paddle/phi/infermeta/unary.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5583,7 +5583,9 @@ void UnfoldInferMeta(const MetaTensor& x,
55835583
std::vector<int> out_dims;
55845584
out_dims.push_back(in_dims[0]); // NOLINT
55855585
int output_channels =
5586-
static_cast<int>(in_dims[1] * kernel_sizes[0] * kernel_sizes[1]);
5586+
in_dims[1] < 0
5587+
? -1
5588+
: static_cast<int>(in_dims[1] * kernel_sizes[0] * kernel_sizes[1]);
55875589
out_dims.push_back(output_channels);
55885590

55895591
int output_height = phi::funcs::CalcOutputSize(static_cast<int>(in_dims[2]),

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

Lines changed: 70 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -718,68 +718,80 @@ def symbolic_call(
718718
func : the logical function which will be represent as a stmt
719719
"""
720720

721-
def try_infer_meta_fn(args, kwargs) -> Any:
722-
try:
723-
metas = convert_to_meta(args)
724-
kwmetas = convert_to_meta(kwargs)
725-
return args, kwargs, infer_meta_fn(func, *metas, **kwmetas)
726-
except (NotSupportedTensorArgumentError, TypeError) as e:
727-
bound_arguments = inspect.signature(func).bind(*args, **kwargs)
728-
bound_arguments.apply_defaults()
729-
if (
730-
isinstance(e, NotSupportedTensorArgumentError)
731-
and e.name in bound_arguments.arguments
721+
def infer_meta(args, kwargs):
722+
metas = convert_to_meta(args)
723+
kwmetas = convert_to_meta(kwargs)
724+
return infer_meta_fn(func, *metas, **kwmetas)
725+
726+
def fallback_symbolic_to_constant(args, kwargs, err):
727+
bound_arguments = inspect.signature(func).bind(*args, **kwargs)
728+
bound_arguments.apply_defaults()
729+
if (
730+
isinstance(err, NotSupportedTensorArgumentError)
731+
and err.name in bound_arguments.arguments
732+
):
733+
original_var = bound_arguments.arguments[err.name]
734+
flatten_vars = original_var.flatten_inner_vars()
735+
if not any(
736+
isinstance(arg, SymbolicVariable) for arg in flatten_vars
732737
):
733-
original_var = bound_arguments.arguments[e.name]
734-
flatten_vars = original_var.flatten_inner_vars()
735-
if not any(
736-
isinstance(arg, SymbolicVariable)
737-
for arg in flatten_vars
738-
):
739-
# TODO(zrr1999): maybe we can continue to fallback to all args are constant.
740-
raise BreakGraphError(
741-
InferMetaBreak(
742-
f"InferMeta encount {type(e)}, but all args are not symbolic."
743-
)
738+
# TODO(zrr1999): maybe we can continue to fallback to all args are constant.
739+
raise BreakGraphError(
740+
InferMetaBreak(
741+
f"InferMeta encountered {type(err)}, but all args are not symbolic."
744742
)
745-
746-
args, kwargs = map_if(
747-
(args, kwargs),
748-
pred=lambda x: x is original_var,
749-
true_fn=lambda x: replace_symbolic_var_with_constant_var(
750-
x
751-
),
752-
false_fn=lambda x: x,
753-
)
754-
else:
755-
flatten_vars = reduce(
756-
lambda x, y: (
757-
x + y.flatten_inner_vars()
758-
if isinstance(y, VariableBase)
759-
else x
760-
),
761-
bound_arguments.arguments.values(),
762-
[],
763743
)
764744

765-
if not any(
766-
isinstance(arg, SymbolicVariable)
767-
for arg in flatten_vars
768-
):
769-
raise BreakGraphError(
770-
InferMetaBreak(
771-
f"InferMeta encount {type(e)}, but all args are not symbolic."
772-
)
773-
)
745+
args, kwargs = map_if(
746+
(args, kwargs),
747+
pred=lambda x: x is original_var,
748+
true_fn=lambda x: replace_symbolic_var_with_constant_var(x),
749+
false_fn=lambda x: x,
750+
)
751+
else:
752+
flatten_vars = reduce(
753+
lambda x, y: (
754+
x + y.flatten_inner_vars()
755+
if isinstance(y, VariableBase)
756+
else x
757+
),
758+
bound_arguments.arguments.values(),
759+
[],
760+
)
774761

775-
args, kwargs = map_structure(
776-
replace_symbolic_var_with_constant_var, (args, kwargs)
762+
if not any(
763+
isinstance(arg, SymbolicVariable) for arg in flatten_vars
764+
):
765+
raise BreakGraphError(
766+
InferMetaBreak(
767+
f"InferMeta encountered {type(err)}, but all args are not symbolic."
768+
)
777769
)
778770

779-
metas = convert_to_meta(args)
780-
kwmetas = convert_to_meta(kwargs)
781-
return args, kwargs, infer_meta_fn(func, *metas, **kwmetas)
771+
args, kwargs = map_structure(
772+
replace_symbolic_var_with_constant_var, (args, kwargs)
773+
)
774+
return args, kwargs
782775

776+
def try_infer_meta_with_fallback_symbolic_to_constant(
777+
args, kwargs, max_retry_times=10
778+
):
779+
try:
780+
return args, kwargs, infer_meta(args, kwargs)
781+
except (NotSupportedTensorArgumentError, TypeError) as e:
782+
err = e
783+
retry_times = 0
784+
while True:
785+
retry_times += 1
786+
if retry_times >= max_retry_times:
787+
raise err
788+
try:
789+
args, kwargs = fallback_symbolic_to_constant(
790+
args, kwargs, err
791+
)
792+
return args, kwargs, infer_meta(args, kwargs)
793+
except (NotSupportedTensorArgumentError, TypeError) as e:
794+
err = e
783795
except Exception as e:
784796
if SotExtraInfo.from_exception(e).need_breakgraph:
785797
raise BreakGraphError(
@@ -790,11 +802,11 @@ def try_infer_meta_fn(args, kwargs) -> Any:
790802
raise e
791803

792804
if ENV_SOT_ALLOW_DYNAMIC_SHAPE.get():
793-
args, kwargs, out_metas = try_infer_meta_fn(args, kwargs)
805+
args, kwargs, out_metas = (
806+
try_infer_meta_with_fallback_symbolic_to_constant(args, kwargs)
807+
)
794808
else:
795-
metas = convert_to_meta(args)
796-
kwmetas = convert_to_meta(kwargs)
797-
out_metas = infer_meta_fn(func, *metas, **kwmetas)
809+
out_metas = infer_meta(args, kwargs)
798810

799811
self.collect_input_variables(list(args))
800812
self.collect_input_variables(list(kwargs.values()))

python/paddle/nn/functional/common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,10 @@ def unfold(
201201
"paddings should either be an integer or a list/tuple of 2 or 4 integers"
202202
)
203203
else:
204-
raise ValueError(
205-
"Unexpected type of paddings, it should be either an integer or a list/tuple"
206-
"of 2 or 4 integers"
204+
raise NotSupportedTensorArgumentError(
205+
"Unexpected type of paddings, it should be either an integer or a list/tuple "
206+
"of 2 or 4 integers",
207+
"paddings",
207208
)
208209

209210
if in_dynamic_or_pir_mode():

0 commit comments

Comments
 (0)