diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 84da603ae..0ecc57c5a 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -301,7 +301,7 @@ def start_compile(self, *ret_vars: VariableBase): found = False for variable in self.input_variables: if ( - isinstance(variable, (TensorVariable, PaddleLayerVariable)) + isinstance(variable, TensorVariable) and variable.get_symbol().name == name ): variable.tracker.gen_instructions(self.pycode_gen) @@ -426,15 +426,12 @@ def call_layer( """ def infer_meta_fn(layer, *metas, **kwmetas): - metas = metas[1:] metas = LayerInferMetaCache()(layer.value, *metas, **kwmetas) return metas def compute_fn(layer, inputs, outputs, stacks): - inputs = (layer.get_symbol(), *inputs) - inputs = inputs[1:] self.sir_ctx.call_LAYER( - layer.value.__class__.__name__, + layer.value, inputs=inputs, outputs=outputs, stacks=stacks, @@ -444,7 +441,7 @@ def message_handler(*args, **kwargs): return f"Call paddle layer error: {layer}, may be not a valid paddle layer ?" return inner_error_default_handler(self.symbolic_call, message_handler)( - infer_meta_fn, compute_fn, layer, *[layer, *args], **kwargs + infer_meta_fn, compute_fn, layer, *args, **kwargs ) @event_register("symbolic_call", event_level=2) diff --git a/sot/opcode_translator/executor/variables/callable.py b/sot/opcode_translator/executor/variables/callable.py index 59dc8c98f..4d9c7fc66 100644 --- a/sot/opcode_translator/executor/variables/callable.py +++ b/sot/opcode_translator/executor/variables/callable.py @@ -9,10 +9,8 @@ import paddle from .... import psdb -from ....symbolic.statement_ir import Symbol from ....utils import ( EventGuard, - NameGenerator, is_break_graph_api, is_break_graph_tensor_methods, is_builtin_fn, @@ -503,18 +501,13 @@ class PaddleLayerVariable(LayerVariable): tracker(Tracker): The Tracker object that tracks the information of this variable. """ - layer_name_generator = NameGenerator("layer_") - def __init__( self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker ): super().__init__(layer, graph, tracker) - self.name = self.layer_name_generator.next() - - def get_symbol(self) -> Symbol: - return Symbol(self.name) def call_function(self, /, *args, **kwargs): + self.graph.add_global_guarded_variable(self) return self.graph.call_layer(self, *args, **kwargs) def make_stringify_guard(self) -> list[StringifyExpression]: diff --git a/sot/symbolic/interpreter.py b/sot/symbolic/interpreter.py index a70c484b2..da057c53b 100644 --- a/sot/symbolic/interpreter.py +++ b/sot/symbolic/interpreter.py @@ -120,27 +120,25 @@ def _set(v, s): return replace_symbol(SIR.outputs, state) def call(self, stmt: Statement, inputs): - SIR = self.get_sir(stmt.name) + SIR = self.get_sir(stmt.sir_name) state = prepare_state(SIR, inputs) - return self.run_sir(stmt.name, state) + return self.run_sir(stmt.sir_name, state) def api(self, stmt, inputs): args, kwargs = inputs - return stmt.name(*args, **kwargs) + return stmt.api(*args, **kwargs) def method(self, stmt, inputs): args, kwargs = inputs var = args[0] - return getattr(var, stmt.name)(*args[1:], **kwargs) + return getattr(var, stmt.method)(*args[1:], **kwargs) def layer(self, stmt, inputs): args, kwargs = inputs - layer, args = args[0], args[1:] + layer = stmt.layer() + assert layer is not None, "SIR bound layer is None." return layer(*args, **kwargs) - def delete(self, stmt, inputs): - pass - def compile_sir(context: SymbolicTraceContext, name: str): """ diff --git a/sot/symbolic/statement_ir.py b/sot/symbolic/statement_ir.py index 5901fb198..542eb71d9 100644 --- a/sot/symbolic/statement_ir.py +++ b/sot/symbolic/statement_ir.py @@ -5,6 +5,10 @@ """ from __future__ import annotations +import weakref +from typing import Callable + +import paddle from paddle.utils import is_sequence, map_structure from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend @@ -69,15 +73,10 @@ def to_string(inps): inps = (x.__str__() for x in inps) return ", ".join(inps) - name = ( - self.name - if isinstance(self.name, str) - else "paddle." + self.name.__name__ - ) return "{} || {} = {} ({}) ".format( self.type + " " * (10 - len(self.type)), to_string(self.outputs), - name, + self.name, to_string(self.inputs), ) @@ -85,6 +84,58 @@ def __repr__(self): return self.__str__() +class CallStatement(Statement): + def __init__( + self, + name: str, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__("call", name, inputs, outputs, stacks) + self.sir_name = name + + +class ApiStatement(Statement): + def __init__( + self, + api: Callable, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__( + "api", "paddle." + api.__name__, inputs, outputs, stacks + ) + self.api = api + + +class MethodStatement(Statement): + def __init__( + self, + name: str, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__("method", name, inputs, outputs, stacks) + self.method = name + + +class LayerStatement(Statement): + def __init__( + self, + layer: paddle.nn.Layer, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__( + "layer", layer.__class__.__name__, inputs, outputs, stacks + ) + self.layer = weakref.ref(layer) + + class StatementIR: """ StatementIR is the carrier that records the code for building the neural network model.It is diff --git a/sot/symbolic/symbolic_context.py b/sot/symbolic/symbolic_context.py index 0feffa6e5..9b8509d9e 100644 --- a/sot/symbolic/symbolic_context.py +++ b/sot/symbolic/symbolic_context.py @@ -2,7 +2,15 @@ from ..utils import event_register, log from .compile_cache import CompileSIRCache -from .statement_ir import Statement, StatementIR, StatementIRFactory, Symbol +from .statement_ir import ( + ApiStatement, + CallStatement, + LayerStatement, + MethodStatement, + StatementIR, + StatementIRFactory, + Symbol, +) class SymbolicTraceContext: @@ -41,7 +49,7 @@ def call_SIR(self, sirname, inputs, outputs, stacks): Call a SIR, which is a subgraph. """ - stmt = Statement("call", sirname, inputs, outputs, stacks) + stmt = CallStatement(sirname, inputs, outputs, stacks) self.TOS.add_statement(stmt) @event_register("call_API", event_level=2) @@ -51,7 +59,7 @@ def call_API(self, api, inputs, outputs, stacks): """ assert callable(api), "call_API must receive a paddle api." - stmt = Statement("api", api, inputs, outputs, stacks) + stmt = ApiStatement(api, inputs, outputs, stacks) self.TOS.add_statement(stmt) @event_register("call_METHOD", event_level=2) @@ -65,15 +73,15 @@ def call_METHOD(self, method_name, inputs, outputs, stacks): assert isinstance( inputs[0][0], Symbol ), "call_METHOD must first augument must be Symbol Variable." - stmt = Statement("method", method_name, inputs, outputs, stacks) + stmt = MethodStatement(method_name, inputs, outputs, stacks) self.TOS.add_statement(stmt) @event_register("call_LAYER", event_level=2) - def call_LAYER(self, layer_name, inputs, outputs, stacks): + def call_LAYER(self, layer, inputs, outputs, stacks): """ Call a layer of a api. """ - stmt = Statement("layer", layer_name, inputs, outputs, stacks) + stmt = LayerStatement(layer, inputs, outputs, stacks) self.TOS.add_statement(stmt) def get_sir(self, name: str):