diff --git a/symbolic_trace/infer_meta.py b/symbolic_trace/infer_meta.py index 86f68add8..0a13dd8fe 100644 --- a/symbolic_trace/infer_meta.py +++ b/symbolic_trace/infer_meta.py @@ -2,7 +2,7 @@ from paddle.fluid.framework import Program from paddle.utils import flatten -from .utils import Cache, Singleton, map_if, meta_str, no_eval_frame +from .utils import Cache, Singleton, map_if, meta_str @Singleton @@ -90,14 +90,8 @@ def infer_meta(self, func, *args, **kwargs): else: out = func(*args, **kwargs) - out = MetaInfo( - list(out.shape), - out.dtype, - out.stop_gradient, - ) - paddle.disable_static() - return out + return variable_to_meta_info(out) def convert_to_variable(args): @@ -118,7 +112,19 @@ def convert_to_input_spec(args): ) -@no_eval_frame +def variable_to_meta_info(args): + return map_if( + args, + pred=lambda x: isinstance(x, paddle.static.Variable), + true_fn=lambda x: MetaInfo( + list(x.shape), + x.dtype, + x.stop_gradient, + ), + false_fn=lambda x: x, + ) + + def infer_meta(func, *args, **kwargs): return VariableCreator().infer_meta(func, *args, **kwargs) diff --git a/symbolic_trace/opcode_translator/executor/function_graph.py b/symbolic_trace/opcode_translator/executor/function_graph.py index 35daa6028..8f1521aa5 100644 --- a/symbolic_trace/opcode_translator/executor/function_graph.py +++ b/symbolic_trace/opcode_translator/executor/function_graph.py @@ -7,10 +7,10 @@ from copy import deepcopy from typing import Any, Callable -from ...infer_meta import InferMetaCache, infer_meta, infer_meta_for_layer +from ...infer_meta import MetaInfo, infer_meta, infer_meta_for_layer from ...symbolic.statement_ir import Symbol from ...symbolic.symbolic_context import SymbolicTraceContext -from ...utils import is_paddle_api, log, show_trackers +from ...utils import is_paddle_api, log, map_if, show_trackers from .guard import Guard, StringifyExpression, make_guard from .pycode_generator import PyCodeGen from .tracker import DummyTracker @@ -19,6 +19,7 @@ PaddleLayerVariable, TensorVariable, VariableBase, + VariableFactory, map_variables, topo_sort_vars, ) @@ -35,8 +36,8 @@ def func(x): def convert_to_symbol(inputs): def func(x): - if isinstance(x, TensorVariable): - return Symbol(x.var_name) + if isinstance(x, (TensorVariable, PaddleLayerVariable)): + return x.get_symbol() return x.get_value() return map_variables(func, inputs) @@ -168,46 +169,47 @@ def call_paddle_api( ): assert is_paddle_api(func) # not fallback api, start symbolic trace. - # TODO(xiokgun): multi-output support. # TODO(xiokgun): may have python buildin object inside metas. # TODO(xiokgun): 4 kinds of python arguments. support it !! log(3, f"call paddle.api : {func.__name__}", "\n") + return self.symbolic_call( + infer_meta, self.sir_ctx.call_API, func, *args, **kwargs + ) + + def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs): + """infer_meta_fn: function for infer meta, (func, metas, kwmetas) -> output_metas + compute_fn : function for sir compile, (func, input_symbols, outputs_symbols) -> None + """ self.collect_input_variables(list(args)) self.collect_input_variables(list(kwargs.values())) metas = convert_to_meta(args) kwmetas = convert_to_meta(kwargs) - meta = InferMetaCache()(func, *metas, **kwmetas) + out_metas = infer_meta_fn(func, *metas, **kwmetas) inputs_symbols = ( convert_to_symbol(args), convert_to_symbol(kwargs), ) log(3, f" inputs : {inputs_symbols}", "\n") - variable = TensorVariable( - meta, - self, - tracker=DummyTracker(list(args) + list(kwargs.values())), + outputs = map_if( + out_metas, + pred=lambda x: isinstance(x, MetaInfo), + true_fn=lambda x: TensorVariable( + x, + self, + tracker=DummyTracker(list(args) + list(kwargs.values())), + ), + false_fn=lambda x: x, ) - self.sir_ctx.call_API( - func, - inputs=inputs_symbols, - outputs=convert_to_symbol(variable), + compute_fn( + func, inputs_symbols, convert_to_symbol(outputs) ) # symbolic only contain symbols. - - self._put_inner(variable) - return variable + self._put_inner(outputs) + return VariableFactory.from_value(outputs, self, DummyTracker(outputs)) def call_tensor_method(self, method_name: str, *args: VariableBase): - self.collect_input_variables(list(args)) - metas = convert_to_meta(args) - meta = infer_meta(method_name, *metas) - variable = TensorVariable(meta, self, tracker=DummyTracker(list(args))) - self.sir_ctx.call_METHOD( - method_name, - inputs=(convert_to_symbol(args), {}), - outputs=convert_to_symbol(variable), - ) # symbolic only contain symbols. - self._put_inner(variable) - return variable + return self.symbolic_call( + infer_meta, self.sir_ctx.call_METHOD, method_name, *args + ) def call_layer( self, @@ -215,30 +217,31 @@ def call_layer( *args: VariableBase, **kwargs: VariableBase, ): - self.collect_input_variables([layer, *args]) - self.collect_input_variables(list(kwargs.values())) - metas = convert_to_meta(args) - kwmetas = convert_to_meta(kwargs) - meta = infer_meta_for_layer(layer.value, *metas, **kwmetas) - inputs_symbols = ( - (layer.get_symbol(), *convert_to_symbol(args)), - convert_to_symbol(kwargs), - ) - variable = TensorVariable( - meta, - self, - tracker=DummyTracker([layer, *args] + list(kwargs.values())), - ) - self.sir_ctx.call_LAYER( - layer.value.__class__.__name__, - inputs=inputs_symbols, - outputs=convert_to_symbol(variable), + def infer_meta_fn(layer, *metas, **kwmetas): + metas = metas[1:] + metas = infer_meta_for_layer(layer.value, *metas, **kwmetas) + return metas + + def compute_fn(layer, inputs, outputs): + inputs = (layer.get_symbol(), *inputs) + inputs = inputs[1:] + self.sir_ctx.call_LAYER( + layer.value.__class__.__name__, + inputs=inputs, + outputs=outputs, + ) + + return self.symbolic_call( + infer_meta_fn, compute_fn, layer, *[layer, *args] ) - self._put_inner(variable) - return variable def _put_inner(self, var): - self.inner_out.add(var.id) + map_if( + var, + pred=lambda x: isinstance(x, VariableBase), + true_fn=lambda x: self.inner_out.add(x.id), + false_fn=lambda x: None, + ) def add_global_guarded_variable(self, variable: VariableBase): self._global_guarded_variables.append(variable)