Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

FunctionGraph function refactor and infer meta support multi-outputs #111

Merged
merged 5 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions symbolic_trace/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from paddle.fluid.framework import Program
from paddle.utils import flatten

from .utils import Cache, Singleton, map_if, meta_str, no_eval_frame
from .utils import Cache, Singleton, map_if, meta_str


@Singleton
Expand Down Expand Up @@ -90,14 +90,8 @@ def infer_meta(self, func, *args, **kwargs):
else:
out = func(*args, **kwargs)

out = MetaInfo(
list(out.shape),
out.dtype,
out.stop_gradient,
)

paddle.disable_static()
return out
return variable_to_meta_info(out)


def convert_to_variable(args):
Expand All @@ -118,7 +112,19 @@ def convert_to_input_spec(args):
)


@no_eval_frame
def variable_to_meta_info(args):
return map_if(
args,
pred=lambda x: isinstance(x, paddle.static.Variable),
true_fn=lambda x: MetaInfo(
list(x.shape),
x.dtype,
x.stop_gradient,
),
false_fn=lambda x: x,
)


def infer_meta(func, *args, **kwargs):
return VariableCreator().infer_meta(func, *args, **kwargs)

Expand Down
101 changes: 52 additions & 49 deletions symbolic_trace/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from copy import deepcopy
from typing import Any, Callable

from ...infer_meta import InferMetaCache, infer_meta, infer_meta_for_layer
from ...infer_meta import MetaInfo, infer_meta, infer_meta_for_layer
from ...symbolic.statement_ir import Symbol
from ...symbolic.symbolic_context import SymbolicTraceContext
from ...utils import is_paddle_api, log, show_trackers
from ...utils import is_paddle_api, log, map_if, show_trackers
from .guard import Guard, StringifyExpression, make_guard
from .pycode_generator import PyCodeGen
from .tracker import DummyTracker
Expand All @@ -19,6 +19,7 @@
PaddleLayerVariable,
TensorVariable,
VariableBase,
VariableFactory,
map_variables,
topo_sort_vars,
)
Expand All @@ -35,8 +36,8 @@ def func(x):

def convert_to_symbol(inputs):
def func(x):
if isinstance(x, TensorVariable):
return Symbol(x.var_name)
if isinstance(x, (TensorVariable, PaddleLayerVariable)):
return x.get_symbol()
return x.get_value()

return map_variables(func, inputs)
Expand Down Expand Up @@ -168,77 +169,79 @@ def call_paddle_api(
):
assert is_paddle_api(func)
# not fallback api, start symbolic trace.
# TODO(xiokgun): multi-output support.
# TODO(xiokgun): may have python buildin object inside metas.
# TODO(xiokgun): 4 kinds of python arguments. support it !!
log(3, f"call paddle.api : {func.__name__}", "\n")
return self.symbolic_call(
infer_meta, self.sir_ctx.call_API, func, *args, **kwargs
)

def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
"""infer_meta_fn: function for infer meta, (func, metas, kwmetas) -> output_metas
compute_fn : function for sir compile, (func, input_symbols, outputs_symbols) -> None
"""
self.collect_input_variables(list(args))
self.collect_input_variables(list(kwargs.values()))
metas = convert_to_meta(args)
kwmetas = convert_to_meta(kwargs)
meta = InferMetaCache()(func, *metas, **kwmetas)
out_metas = infer_meta_fn(func, *metas, **kwmetas)
inputs_symbols = (
convert_to_symbol(args),
convert_to_symbol(kwargs),
)
log(3, f" inputs : {inputs_symbols}", "\n")
variable = TensorVariable(
meta,
self,
tracker=DummyTracker(list(args) + list(kwargs.values())),
outputs = map_if(
out_metas,
pred=lambda x: isinstance(x, MetaInfo),
true_fn=lambda x: TensorVariable(
x,
self,
tracker=DummyTracker(list(args) + list(kwargs.values())),
),
false_fn=lambda x: x,
)
self.sir_ctx.call_API(
func,
inputs=inputs_symbols,
outputs=convert_to_symbol(variable),
compute_fn(
func, inputs_symbols, convert_to_symbol(outputs)
) # symbolic only contain symbols.

self._put_inner(variable)
return variable
self._put_inner(outputs)
return VariableFactory.from_value(outputs, self, DummyTracker(outputs))

def call_tensor_method(self, method_name: str, *args: VariableBase):
self.collect_input_variables(list(args))
metas = convert_to_meta(args)
meta = infer_meta(method_name, *metas)
variable = TensorVariable(meta, self, tracker=DummyTracker(list(args)))
self.sir_ctx.call_METHOD(
method_name,
inputs=(convert_to_symbol(args), {}),
outputs=convert_to_symbol(variable),
) # symbolic only contain symbols.
self._put_inner(variable)
return variable
return self.symbolic_call(
infer_meta, self.sir_ctx.call_METHOD, method_name, *args
)

def call_layer(
self,
layer: PaddleLayerVariable,
*args: VariableBase,
**kwargs: VariableBase,
):
self.collect_input_variables([layer, *args])
self.collect_input_variables(list(kwargs.values()))
metas = convert_to_meta(args)
kwmetas = convert_to_meta(kwargs)
meta = infer_meta_for_layer(layer.value, *metas, **kwmetas)
inputs_symbols = (
(layer.get_symbol(), *convert_to_symbol(args)),
convert_to_symbol(kwargs),
)
variable = TensorVariable(
meta,
self,
tracker=DummyTracker([layer, *args] + list(kwargs.values())),
)
self.sir_ctx.call_LAYER(
layer.value.__class__.__name__,
inputs=inputs_symbols,
outputs=convert_to_symbol(variable),
def infer_meta_fn(layer, *metas, **kwmetas):
metas = metas[1:]
metas = infer_meta_for_layer(layer.value, *metas, **kwmetas)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种有必要的地方wrapper一层是没有问题的

return metas

def compute_fn(layer, inputs, outputs):
inputs = (layer.get_symbol(), *inputs)
inputs = inputs[1:]
self.sir_ctx.call_LAYER(
layer.value.__class__.__name__,
inputs=inputs,
outputs=outputs,
)

return self.symbolic_call(
infer_meta_fn, compute_fn, layer, *[layer, *args]
)
self._put_inner(variable)
return variable

def _put_inner(self, var):
self.inner_out.add(var.id)
map_if(
var,
pred=lambda x: isinstance(x, VariableBase),
true_fn=lambda x: self.inner_out.add(x.id),
false_fn=lambda x: None,
)

def add_global_guarded_variable(self, variable: VariableBase):
self._global_guarded_variables.append(variable)
Expand Down